@@ -3885,6 +3885,7 @@ static struct ggml_tensor * ggml_rope_impl(
38853885 struct ggml_tensor * b ,
38863886 struct ggml_tensor * c ,
38873887 int n_dims ,
3888+ int sections [GGML_MROPE_SECTIONS ],
38883889 int mode ,
38893890 int n_ctx_orig ,
38903891 float freq_base ,
@@ -3898,15 +3899,19 @@ static struct ggml_tensor * ggml_rope_impl(
38983899
38993900 GGML_ASSERT (ggml_is_vector (b ));
39003901 GGML_ASSERT (b -> type == GGML_TYPE_I32 );
3901- GGML_ASSERT (a -> ne [2 ] == b -> ne [0 ]);
3902+
3903+ bool mrope_used = mode & GGML_ROPE_TYPE_MROPE ;
3904+ if (mrope_used ) {
3905+ GGML_ASSERT (a -> ne [2 ] * 4 == b -> ne [0 ]); // mrope expecting 4 position ids per token
3906+ } else {
3907+ GGML_ASSERT (a -> ne [2 ] == b -> ne [0 ]);
3908+ }
39023909
39033910 if (c ) {
39043911 GGML_ASSERT (c -> type == GGML_TYPE_F32 );
39053912 GGML_ASSERT (c -> ne [0 ] >= n_dims / 2 );
39063913 }
39073914
3908- int sections [4 ] = {0 , 0 , 0 , 0 };
3909-
39103915 struct ggml_tensor * result = inplace ? ggml_view_tensor (ctx , a ) : ggml_dup_tensor (ctx , a );
39113916
39123917 int32_t params [15 ] = { /*n_past*/ 0 , n_dims , mode , /*n_ctx*/ 0 , n_ctx_orig };
@@ -3916,7 +3921,10 @@ static struct ggml_tensor * ggml_rope_impl(
39163921 memcpy (params + 8 , & attn_factor , sizeof (float ));
39173922 memcpy (params + 9 , & beta_fast , sizeof (float ));
39183923 memcpy (params + 10 , & beta_slow , sizeof (float ));
3919- memcpy (params + 11 , & sections , sizeof (int )* 4 );
3924+ if (mrope_used )
3925+ memcpy (params + 11 , sections , sizeof (int32_t ) * GGML_MROPE_SECTIONS );
3926+ else
3927+ memset (params + 11 , 0 , sizeof (int32_t ) * GGML_MROPE_SECTIONS );
39203928 ggml_set_op_params (result , params , sizeof (params ));
39213929
39223930 result -> op = GGML_OP_ROPE ;
@@ -3934,7 +3942,7 @@ struct ggml_tensor * ggml_rope(
39343942 int n_dims ,
39353943 int mode ) {
39363944 return ggml_rope_impl (
3937- ctx , a , b , NULL , n_dims , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , false
3945+ ctx , a , b , NULL , n_dims , NULL , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , false
39383946 );
39393947}
39403948
@@ -3944,7 +3952,7 @@ struct ggml_tensor * ggml_rope_multi(
39443952 struct ggml_tensor * b ,
39453953 struct ggml_tensor * c ,
39463954 int n_dims ,
3947- int sections [4 ],
3955+ int sections [GGML_MROPE_SECTIONS ],
39483956 int mode ,
39493957 int n_ctx_orig ,
39503958 float freq_base ,
@@ -3953,36 +3961,31 @@ struct ggml_tensor * ggml_rope_multi(
39533961 float attn_factor ,
39543962 float beta_fast ,
39553963 float beta_slow ) {
3956- // Multimodal Rotary Position Embedding
3957- GGML_ASSERT ((mode & 1 ) == 0 && "mode & 1 == 1 is no longer supported" );
3958-
3959- GGML_ASSERT (ggml_is_vector (b ));
3960- GGML_ASSERT (b -> type == GGML_TYPE_I32 );
3961- GGML_ASSERT (a -> ne [2 ] * 4 == b -> ne [0 ]); // mrope expecting 4 position ids per token
3962-
3963- if (c ) {
3964- GGML_ASSERT (c -> type == GGML_TYPE_F32 );
3965- GGML_ASSERT (c -> ne [0 ] >= n_dims / 2 );
3966- }
3967-
3968- struct ggml_tensor * result = ggml_dup_tensor (ctx , a );
3969-
3970- int32_t params [11 + 4 ] = { /*n_past*/ 0 , n_dims , mode , /*n_ctx*/ 0 , n_ctx_orig };
3971- memcpy (params + 5 , & freq_base , sizeof (float ));
3972- memcpy (params + 6 , & freq_scale , sizeof (float ));
3973- memcpy (params + 7 , & ext_factor , sizeof (float ));
3974- memcpy (params + 8 , & attn_factor , sizeof (float ));
3975- memcpy (params + 9 , & beta_fast , sizeof (float ));
3976- memcpy (params + 10 , & beta_slow , sizeof (float ));
3977- memcpy (& params [11 ], sections , sizeof (int )* 4 );
3978- ggml_set_op_params (result , params , sizeof (params ));
3979-
3980- result -> op = GGML_OP_ROPE ;
3981- result -> src [0 ] = a ;
3982- result -> src [1 ] = b ;
3983- result -> src [2 ] = c ;
3964+ return ggml_rope_impl (
3965+ ctx , a , b , c , n_dims , sections , mode , n_ctx_orig , freq_base , freq_scale ,
3966+ ext_factor , attn_factor , beta_fast , beta_slow , false
3967+ );
3968+ }
39843969
3985- return result ;
3970+ struct ggml_tensor * ggml_rope_multi_inplace ( // in-place counterpart of ggml_rope_multi: takes the same arguments and forwards them unchanged to ggml_rope_impl, differing only in the final flag
3971+ struct ggml_context * ctx ,
3972+ struct ggml_tensor * a , // source tensor; ggml_rope_impl builds its result as a view or a duplicate of this tensor
3973+ struct ggml_tensor * b , // I32 vector of position ids; ggml_rope_impl asserts b->ne[0] == a->ne[2] * 4 when mode has GGML_ROPE_TYPE_MROPE set, otherwise b->ne[0] == a->ne[2]
3974+ struct ggml_tensor * c , // optional, may be NULL; when given, ggml_rope_impl asserts it is F32 with ne[0] >= n_dims / 2
3975+ int n_dims ,
3976+ int sections [GGML_MROPE_SECTIONS ], // copied into the op params starting at index 11 only when mode has GGML_ROPE_TYPE_MROPE set (zeros are stored otherwise); must be non-NULL in that case because ggml_rope_impl memcpy's from it
3977+ int mode , // bit flags; the GGML_ROPE_TYPE_MROPE bit selects the multi-section path in ggml_rope_impl
3978+ int n_ctx_orig ,
3979+ float freq_base ,
3980+ float freq_scale ,
3981+ float ext_factor ,
3982+ float attn_factor ,
3983+ float beta_fast ,
3984+ float beta_slow ) {
3985+ return ggml_rope_impl ( // all validation and op-parameter packing happens in the shared implementation
3986+ ctx , a , b , c , n_dims , sections , mode , n_ctx_orig , freq_base , freq_scale ,
3987+ ext_factor , attn_factor , beta_fast , beta_slow , true // `true` here vs `false` in ggml_rope_multi; presumably ggml_rope_impl's `inplace` flag (result = ggml_view_tensor(ctx, a) instead of ggml_dup_tensor(ctx, a)) - confirm against the full signature
3988+ );
39863989}
39873990
39883991struct ggml_tensor * ggml_rope_inplace (
@@ -3992,7 +3995,7 @@ struct ggml_tensor * ggml_rope_inplace(
39923995 int n_dims ,
39933996 int mode ) {
39943997 return ggml_rope_impl (
3995- ctx , a , b , NULL , n_dims , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , true
3998+ ctx , a , b , NULL , n_dims , NULL , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , true
39963999 );
39974000}
39984001
@@ -4011,7 +4014,7 @@ struct ggml_tensor * ggml_rope_ext(
40114014 float beta_fast ,
40124015 float beta_slow ) {
40134016 return ggml_rope_impl (
4014- ctx , a , b , c , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4017+ ctx , a , b , c , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40154018 ext_factor , attn_factor , beta_fast , beta_slow , false
40164019 );
40174020}
@@ -4031,7 +4034,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
40314034 float beta_fast ,
40324035 float beta_slow ) {
40334036 return ggml_rope_impl (
4034- ctx , a , b , c , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4037+ ctx , a , b , c , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40354038 ext_factor , attn_factor , beta_fast , beta_slow , true
40364039 );
40374040}
@@ -4050,7 +4053,7 @@ struct ggml_tensor * ggml_rope_custom(
40504053 float beta_fast ,
40514054 float beta_slow ) {
40524055 return ggml_rope_impl (
4053- ctx , a , b , NULL , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4056+ ctx , a , b , NULL , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40544057 ext_factor , attn_factor , beta_fast , beta_slow , false
40554058 );
40564059}
@@ -4069,7 +4072,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
40694072 float beta_fast ,
40704073 float beta_slow ) {
40714074 return ggml_rope_impl (
4072- ctx , a , b , NULL , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4075+ ctx , a , b , NULL , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
40734076 ext_factor , attn_factor , beta_fast , beta_slow , true
40744077 );
40754078}
0 commit comments