Skip to content

Commit 49e7840

Browse files
committed
CUDA + OpenCL: fix bug in accessing rms_norm->src[1] while doing fusion
1 parent 477a66b commit 49e7840

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2901,7 +2901,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2901 2901
}
2902 2902

2903 2903
//if rms norm is the B operand, then we don't handle broadcast
2904-
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2904+
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
2905 2905
return false;
2906 2906
}
2907 2907

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2681,7 +2681,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
2681 2681

2682 2682
// if rms_norm is the B operand, then we don't handle broadcast
2683 2683
if (rms_norm == mul->src[1] &&
2684-
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2684+
!ggml_are_same_shape(mul->src[0], rms_norm)) {
2685 2685
return false;
2686 2686
}
2687 2687

0 commit comments

Comments
 (0)