Skip to content

Commit 84e3d8d

Browse files
committed
metal : fix rope kernels buffer check
1 parent a8d57d6 commit 84e3d8d

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ typedef struct {
251251
int32_t sect_1;
252252
int32_t sect_2;
253253
int32_t sect_3;
254+
bool src2;
254255
} ggml_metal_kargs_rope;
255256

256257
typedef struct {

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,6 +2969,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
29692969
/* sect_1 =*/ sect_1,
29702970
/* sect_2 =*/ sect_2,
29712971
/* sect_3 =*/ sect_3,
2972+
/* src2 =*/ op->src[2] != nullptr,
29722973
};
29732974

29742975
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3748,7 +3748,7 @@ kernel void kernel_rope_norm(
37483748

37493749
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
37503750

3751-
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
3751+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
37523752

37533753
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
37543754

@@ -3801,7 +3801,7 @@ kernel void kernel_rope_neox(
38013801

38023802
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
38033803

3804-
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
3804+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
38053805

38063806
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
38073807

@@ -3872,7 +3872,7 @@ kernel void kernel_rope_multi(
38723872

38733873
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
38743874

3875-
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
3875+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
38763876

38773877
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
38783878

@@ -3939,7 +3939,7 @@ kernel void kernel_rope_vision(
39393939
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
39403940
// end of mrope
39413941

3942-
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
3942+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
39433943

39443944
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
39453945

0 commit comments

Comments
 (0)