@@ -3885,6 +3885,7 @@ static struct ggml_tensor * ggml_rope_impl(
38853885 struct ggml_tensor * b ,
38863886 struct ggml_tensor * c ,
38873887 int n_dims ,
3888+ int sections [GGML_MROPE_SECTIONS ],
38883889 int mode ,
38893890 int n_ctx_orig ,
38903891 float freq_base ,
@@ -3898,15 +3899,19 @@ static struct ggml_tensor * ggml_rope_impl(
38983899
38993900 GGML_ASSERT (ggml_is_vector (b ));
39003901 GGML_ASSERT (b -> type == GGML_TYPE_I32 );
3901- GGML_ASSERT (a -> ne [2 ] == b -> ne [0 ]);
3902+
3903+ bool mrope_used = mode & GGML_ROPE_TYPE_MROPE ;
3904+ if (mrope_used ) {
3905+ GGML_ASSERT (a -> ne [2 ] * 4 == b -> ne [0 ]); // mrope expecting 4 position ids per token
3906+ } else {
3907+ GGML_ASSERT (a -> ne [2 ] == b -> ne [0 ]);
3908+ }
39023909
39033910 if (c ) {
39043911 GGML_ASSERT (c -> type == GGML_TYPE_F32 );
39053912 GGML_ASSERT (c -> ne [0 ] >= n_dims / 2 );
39063913 }
39073914
3908- int sections [4 ] = {0 , 0 , 0 , 0 };
3909-
39103915 struct ggml_tensor * result = inplace ? ggml_view_tensor (ctx , a ) : ggml_dup_tensor (ctx , a );
39113916
39123917 int32_t params [15 ] = { /*n_past*/ 0 , n_dims , mode , /*n_ctx*/ 0 , n_ctx_orig };
@@ -3916,7 +3921,11 @@ static struct ggml_tensor * ggml_rope_impl(
39163921 memcpy (params + 8 , & attn_factor , sizeof (float ));
39173922 memcpy (params + 9 , & beta_fast , sizeof (float ));
39183923 memcpy (params + 10 , & beta_slow , sizeof (float ));
3919- memcpy (params + 11 , & sections , sizeof (int )* 4 );
3924+ if (mrope_used ) {
3925+ memcpy (params + 11 , sections , sizeof (int32_t ) * GGML_MROPE_SECTIONS );
3926+ } else {
3927+ memset (params + 11 , 0 , sizeof (int32_t ) * GGML_MROPE_SECTIONS );
3928+ }
39203929 ggml_set_op_params (result , params , sizeof (params ));
39213930
39223931 result -> op = GGML_OP_ROPE ;
@@ -3934,7 +3943,7 @@ struct ggml_tensor * ggml_rope(
39343943 int n_dims ,
39353944 int mode ) {
39363945 return ggml_rope_impl (
3937- ctx , a , b , NULL , n_dims , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , false
3946+ ctx , a , b , NULL , n_dims , NULL , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , false
39383947 );
39393948}
39403949
@@ -3944,7 +3953,7 @@ struct ggml_tensor * ggml_rope_multi(
39443953 struct ggml_tensor * b ,
39453954 struct ggml_tensor * c ,
39463955 int n_dims ,
3947- int sections [4 ],
3956+ int sections [GGML_MROPE_SECTIONS ],
39483957 int mode ,
39493958 int n_ctx_orig ,
39503959 float freq_base ,
@@ -3953,36 +3962,31 @@ struct ggml_tensor * ggml_rope_multi(
39533962 float attn_factor ,
39543963 float beta_fast ,
39553964 float beta_slow ) {
3956- // Multimodal Rotary Position Embedding
3957- GGML_ASSERT ((mode & 1 ) == 0 && "mode & 1 == 1 is no longer supported" );
3958-
3959- GGML_ASSERT (ggml_is_vector (b ));
3960- GGML_ASSERT (b -> type == GGML_TYPE_I32 );
3961- GGML_ASSERT (a -> ne [2 ] * 4 == b -> ne [0 ]); // mrope expecting 4 position ids per token
3962-
3963- if (c ) {
3964- GGML_ASSERT (c -> type == GGML_TYPE_F32 );
3965- GGML_ASSERT (c -> ne [0 ] >= n_dims / 2 );
3966- }
3967-
3968- struct ggml_tensor * result = ggml_dup_tensor (ctx , a );
3969-
3970- int32_t params [11 + 4 ] = { /*n_past*/ 0 , n_dims , mode , /*n_ctx*/ 0 , n_ctx_orig };
3971- memcpy (params + 5 , & freq_base , sizeof (float ));
3972- memcpy (params + 6 , & freq_scale , sizeof (float ));
3973- memcpy (params + 7 , & ext_factor , sizeof (float ));
3974- memcpy (params + 8 , & attn_factor , sizeof (float ));
3975- memcpy (params + 9 , & beta_fast , sizeof (float ));
3976- memcpy (params + 10 , & beta_slow , sizeof (float ));
3977- memcpy (& params [11 ], sections , sizeof (int )* 4 );
3978- ggml_set_op_params (result , params , sizeof (params ));
3979-
3980- result -> op = GGML_OP_ROPE ;
3981- result -> src [0 ] = a ;
3982- result -> src [1 ] = b ;
3983- result -> src [2 ] = c ;
3965+ return ggml_rope_impl (
3966+ ctx , a , b , c , n_dims , sections , mode , n_ctx_orig , freq_base , freq_scale ,
3967+ ext_factor , attn_factor , beta_fast , beta_slow , false
3968+ );
3969+ }
39843970
3985- return result ;
3971+ struct ggml_tensor * ggml_rope_multi_inplace ( // in-place sibling of ggml_rope_multi: forwards every argument unchanged to ggml_rope_impl and returns its result
3972+ struct ggml_context * ctx ,
3973+ struct ggml_tensor * a ,
3974+ struct ggml_tensor * b ,
3975+ struct ggml_tensor * c ,
3976+ int n_dims ,
3977+ int sections [GGML_MROPE_SECTIONS ], // passed through untouched; NOTE(review): ggml_rope_impl memcpy's from this pointer when mode has GGML_ROPE_TYPE_MROPE set, so it presumably must be non-NULL in that case - confirm
3978+ int mode ,
3979+ int n_ctx_orig ,
3980+ float freq_base ,
3981+ float freq_scale ,
3982+ float ext_factor ,
3983+ float attn_factor ,
3984+ float beta_fast ,
3985+ float beta_slow ) {
3986+ return ggml_rope_impl (
3987+ ctx , a , b , c , n_dims , sections , mode , n_ctx_orig , freq_base , freq_scale ,
3988+ ext_factor , attn_factor , beta_fast , beta_slow , true // the only difference from ggml_rope_multi, which passes false here; presumably the impl's `inplace` flag (result built with ggml_view_tensor(ctx, a) instead of ggml_dup_tensor) - confirm against the full signature
3989+ );
39863990}
39873991
39883992struct ggml_tensor * ggml_rope_inplace (
@@ -3992,7 +3996,7 @@ struct ggml_tensor * ggml_rope_inplace(
39923996 int n_dims ,
39933997 int mode ) {
39943998 return ggml_rope_impl (
3995- ctx , a , b , NULL , n_dims , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , true
3999+ ctx , a , b , NULL , n_dims , NULL , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , true
39964000 );
39974001}
39984002
@@ -4011,7 +4015,7 @@ struct ggml_tensor * ggml_rope_ext(
40114015 float beta_fast ,
40124016 float beta_slow ) {
40134017 return ggml_rope_impl (
4014- ctx , a , b , c , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4018+ ctx , a , b , c , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40154019 ext_factor , attn_factor , beta_fast , beta_slow , false
40164020 );
40174021}
@@ -4031,7 +4035,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
40314035 float beta_fast ,
40324036 float beta_slow ) {
40334037 return ggml_rope_impl (
4034- ctx , a , b , c , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4038+ ctx , a , b , c , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40354039 ext_factor , attn_factor , beta_fast , beta_slow , true
40364040 );
40374041}
@@ -4050,7 +4054,7 @@ struct ggml_tensor * ggml_rope_custom(
40504054 float beta_fast ,
40514055 float beta_slow ) {
40524056 return ggml_rope_impl (
4053- ctx , a , b , NULL , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4057+ ctx , a , b , NULL , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40544058 ext_factor , attn_factor , beta_fast , beta_slow , false
40554059 );
40564060}
@@ -4069,7 +4073,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
40694073 float beta_fast ,
40704074 float beta_slow ) {
40714075 return ggml_rope_impl (
4072- ctx , a , b , NULL , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4076+ ctx , a , b , NULL , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40734077 ext_factor , attn_factor , beta_fast , beta_slow , true
40744078 );
40754079}
0 commit comments