Skip to content

Commit f47daed

Browse files
committed
Assume `mul_ptr` is not null when calling fused ops; formatting changes.
1 parent d2e56c5 commit f47daed

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2767,11 +2767,11 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
27672767
#endif
27682768

27692769
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
2770-
if(!ggml_can_fuse(cgraph, node_idx, ops)) {
2770+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
27712771
return false;
27722772
}
27732773

2774-
if(ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
2774+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
27752775
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
27762776
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
27772777

ggml/src/ggml-cuda/norm.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ static __global__ void rms_norm_f32(
122122

123123
const float * mul_ptr = nullptr;
124124
if constexpr (do_multiply) {
125-
if (mul != nullptr) {
126-
mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row;
127-
}
125+
mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row;
128126
}
129127

130128
float tmp = 0.0f; // partial sum for thread in warp
@@ -154,7 +152,7 @@ static __global__ void rms_norm_f32(
154152

155153
for (int col = tid; col < ncols; col += block_size) {
156154
if constexpr (do_multiply) {
157-
dst[col] = scale * x[col] * (mul_ptr ? mul_ptr[col] : 1.0f);
155+
dst[col] = scale * x[col] * mul_ptr[col];
158156
} else {
159157
dst[col] = scale * x[col];
160158
}
@@ -335,6 +333,10 @@ static void rms_norm_mul_f32_cuda(
335333
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
336334
const float eps, cudaStream_t stream) {
337335
const dim3 blocks_num(nrows, nchannels, nsamples);
336+
if(mul == nullptr) {
337+
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
338+
return;
339+
}
338340
if (ncols < 1024) {
339341
const dim3 block_dims(WARP_SIZE, 1, 1);
340342
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample);
@@ -443,7 +445,7 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
443445
const float * src0_d = (const float *) rms_norm_src->data;
444446
const float * mul_d = nullptr;
445447

446-
if(mul_tensor->src[0] == dst) {
448+
if (mul_tensor->src[0] == dst) {
447449
mul_d = (float *) mul_tensor->src[1]->data;
448450
} else if(mul_tensor->src[1] == dst) {
449451
mul_d = (float *) mul_tensor->src[0]->data;

0 commit comments

Comments (0)