@@ -122,9 +122,7 @@ static __global__ void rms_norm_f32(

     const float * mul_ptr = nullptr;
     if constexpr (do_multiply) {
-        if (mul != nullptr) {
-            mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row;
-        }
+        mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row;
     }

     float tmp = 0.0f; // partial sum for thread in warp
@@ -154,7 +152,7 @@ static __global__ void rms_norm_f32(

     for (int col = tid; col < ncols; col += block_size) {
         if constexpr (do_multiply) {
-            dst[col] = scale * x[col] * (mul_ptr ? mul_ptr[col] : 1.0f);
+            dst[col] = scale * x[col] * mul_ptr[col];
         } else {
             dst[col] = scale * x[col];
         }
@@ -335,6 +333,10 @@ static void rms_norm_mul_f32_cuda(
         const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
         const float eps, cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (mul == nullptr) {
+        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+        return;
+    }
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample);
@@ -443,7 +445,7 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
     const float * src0_d = (const float *) rms_norm_src->data;
     const float * mul_d = nullptr;

-    if (mul_tensor->src[0] == dst) {
+    if (mul_tensor->src[0] == dst) {
         mul_d = (float *) mul_tensor->src[1]->data;
     } else if (mul_tensor->src[1] == dst) {
         mul_d = (float *) mul_tensor->src[0]->data;
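
Taken together, the change moves the null handling for `mul` out of the device code: `rms_norm_mul_f32_cuda` now falls back to the plain `rms_norm_f32_cuda` launch when `mul == nullptr`, so the `do_multiply` instantiation of the kernel can dereference `mul_ptr` unconditionally instead of branching per element. A minimal standalone sketch of that dispatch pattern (hypothetical `scale_rows` names, and a constant stand-in for the per-row RMS factor, not the actual ggml code):

```cuda
#include <cuda_runtime.h>

// Compile-time flag: when do_multiply is true, the kernel may assume
// `mul` is non-null, because the host launcher has already screened
// out the nullptr case before choosing this instantiation.
template <bool do_multiply>
static __global__ void scale_rows(const float * x, const float * mul, float * dst, const int ncols) {
    const int   row   = blockIdx.x;
    const float scale = 2.0f; // stand-in for the per-row RMS factor
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        const int i = row*ncols + col;
        if constexpr (do_multiply) {
            dst[i] = scale * x[i] * mul[i]; // no runtime null check in the hot loop
        } else {
            dst[i] = scale * x[i];
        }
    }
}

// Host-side dispatch: decide once per launch rather than once per element.
static void scale_rows_cuda(
        const float * x, const float * mul, float * dst,
        const int ncols, const int nrows, cudaStream_t stream) {
    if (mul == nullptr) {
        scale_rows<false><<<nrows, 256, 0, stream>>>(x, nullptr, dst, ncols);
        return;
    }
    scale_rows<true><<<nrows, 256, 0, stream>>>(x, mul, dst, ncols);
}
```

Because `do_multiply` is a template parameter, the branch is resolved at compile time and each instantiation carries only its own code path; the only runtime check left is the single host-side one per kernel launch.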