@@ -3885,6 +3885,7 @@ static struct ggml_tensor * ggml_rope_impl(
3885
3885
struct ggml_tensor * b ,
3886
3886
struct ggml_tensor * c ,
3887
3887
int n_dims ,
3888
+ int sections [GGML_MROPE_SECTIONS ],
3888
3889
int mode ,
3889
3890
int n_ctx_orig ,
3890
3891
float freq_base ,
@@ -3898,15 +3899,19 @@ static struct ggml_tensor * ggml_rope_impl(
3898
3899
3899
3900
GGML_ASSERT (ggml_is_vector (b ));
3900
3901
GGML_ASSERT (b -> type == GGML_TYPE_I32 );
3901
- GGML_ASSERT (a -> ne [2 ] == b -> ne [0 ]);
3902
+
3903
+ bool mrope_used = mode & GGML_ROPE_TYPE_MROPE ;
3904
+ if (mrope_used ) {
3905
+ GGML_ASSERT (a -> ne [2 ] * 4 == b -> ne [0 ]); // mrope expecting 4 position ids per token
3906
+ } else {
3907
+ GGML_ASSERT (a -> ne [2 ] == b -> ne [0 ]);
3908
+ }
3902
3909
3903
3910
if (c ) {
3904
3911
GGML_ASSERT (c -> type == GGML_TYPE_F32 );
3905
3912
GGML_ASSERT (c -> ne [0 ] >= n_dims / 2 );
3906
3913
}
3907
3914
3908
- int sections [4 ] = {0 , 0 , 0 , 0 };
3909
-
3910
3915
struct ggml_tensor * result = inplace ? ggml_view_tensor (ctx , a ) : ggml_dup_tensor (ctx , a );
3911
3916
3912
3917
int32_t params [15 ] = { /*n_past*/ 0 , n_dims , mode , /*n_ctx*/ 0 , n_ctx_orig };
@@ -3916,7 +3921,11 @@ static struct ggml_tensor * ggml_rope_impl(
3916
3921
memcpy (params + 8 , & attn_factor , sizeof (float ));
3917
3922
memcpy (params + 9 , & beta_fast , sizeof (float ));
3918
3923
memcpy (params + 10 , & beta_slow , sizeof (float ));
3919
- memcpy (params + 11 , & sections , sizeof (int )* 4 );
3924
+ if (mrope_used ) {
3925
+ memcpy (params + 11 , sections , sizeof (int32_t ) * GGML_MROPE_SECTIONS );
3926
+ } else {
3927
+ memset (params + 11 , 0 , sizeof (int32_t ) * GGML_MROPE_SECTIONS );
3928
+ }
3920
3929
ggml_set_op_params (result , params , sizeof (params ));
3921
3930
3922
3931
result -> op = GGML_OP_ROPE ;
@@ -3934,7 +3943,7 @@ struct ggml_tensor * ggml_rope(
3934
3943
int n_dims ,
3935
3944
int mode ) {
3936
3945
return ggml_rope_impl (
3937
- ctx , a , b , NULL , n_dims , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , false
3946
+ ctx , a , b , NULL , n_dims , NULL , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , false
3938
3947
);
3939
3948
}
3940
3949
@@ -3944,7 +3953,7 @@ struct ggml_tensor * ggml_rope_multi(
3944
3953
struct ggml_tensor * b ,
3945
3954
struct ggml_tensor * c ,
3946
3955
int n_dims ,
3947
- int sections [4 ],
3956
+ int sections [GGML_MROPE_SECTIONS ],
3948
3957
int mode ,
3949
3958
int n_ctx_orig ,
3950
3959
float freq_base ,
@@ -3953,36 +3962,31 @@ struct ggml_tensor * ggml_rope_multi(
3953
3962
float attn_factor ,
3954
3963
float beta_fast ,
3955
3964
float beta_slow ) {
3956
- // Multimodal Rotary Position Embedding
3957
- GGML_ASSERT ((mode & 1 ) == 0 && "mode & 1 == 1 is no longer supported" );
3958
-
3959
- GGML_ASSERT (ggml_is_vector (b ));
3960
- GGML_ASSERT (b -> type == GGML_TYPE_I32 );
3961
- GGML_ASSERT (a -> ne [2 ] * 4 == b -> ne [0 ]); // mrope expecting 4 position ids per token
3962
-
3963
- if (c ) {
3964
- GGML_ASSERT (c -> type == GGML_TYPE_F32 );
3965
- GGML_ASSERT (c -> ne [0 ] >= n_dims / 2 );
3966
- }
3967
-
3968
- struct ggml_tensor * result = ggml_dup_tensor (ctx , a );
3969
-
3970
- int32_t params [11 + 4 ] = { /*n_past*/ 0 , n_dims , mode , /*n_ctx*/ 0 , n_ctx_orig };
3971
- memcpy (params + 5 , & freq_base , sizeof (float ));
3972
- memcpy (params + 6 , & freq_scale , sizeof (float ));
3973
- memcpy (params + 7 , & ext_factor , sizeof (float ));
3974
- memcpy (params + 8 , & attn_factor , sizeof (float ));
3975
- memcpy (params + 9 , & beta_fast , sizeof (float ));
3976
- memcpy (params + 10 , & beta_slow , sizeof (float ));
3977
- memcpy (& params [11 ], sections , sizeof (int )* 4 );
3978
- ggml_set_op_params (result , params , sizeof (params ));
3979
-
3980
- result -> op = GGML_OP_ROPE ;
3981
- result -> src [0 ] = a ;
3982
- result -> src [1 ] = b ;
3983
- result -> src [2 ] = c ;
3965
+ return ggml_rope_impl (
3966
+ ctx , a , b , c , n_dims , sections , mode , n_ctx_orig , freq_base , freq_scale ,
3967
+ ext_factor , attn_factor , beta_fast , beta_slow , false
3968
+ );
3969
+ }
3984
3970
3985
- return result ;
3971
// ggml_rope_multi_inplace: counterpart of ggml_rope_multi that forwards every
// argument unchanged to ggml_rope_impl, passing `true` as the final argument
// where ggml_rope_multi passes `false`. In the visible part of ggml_rope_impl
// the `inplace` flag selects ggml_view_tensor(ctx, a) over
// ggml_dup_tensor(ctx, a) for the result (presumably that final argument is
// this flag - confirm against the full ggml_rope_impl signature).
//
// NOTE(review): ggml_rope_impl memcpy's GGML_MROPE_SECTIONS int32 values from
// `sections` whenever `mode` has GGML_ROPE_TYPE_MROPE set, without a NULL
// check, so `sections` must be non-NULL for such modes. In that case it also
// asserts a->ne[2] * 4 == b->ne[0] (4 position ids per token).
+ struct ggml_tensor * ggml_rope_multi_inplace (
3972
+ struct ggml_context * ctx ,
3973
+ struct ggml_tensor * a ,
3974
+ struct ggml_tensor * b ,
3975
+ struct ggml_tensor * c ,
3976
+ int n_dims ,
3977
+ int sections [GGML_MROPE_SECTIONS ],
3978
+ int mode ,
3979
+ int n_ctx_orig ,
3980
+ float freq_base ,
3981
+ float freq_scale ,
3982
+ float ext_factor ,
3983
+ float attn_factor ,
3984
+ float beta_fast ,
3985
+ float beta_slow ) {
3986
+ return ggml_rope_impl (
3987
+ ctx , a , b , c , n_dims , sections , mode , n_ctx_orig , freq_base , freq_scale ,
3988
+ ext_factor , attn_factor , beta_fast , beta_slow , true
3989
+ );
3986
3990
}
3987
3991
3988
3992
struct ggml_tensor * ggml_rope_inplace (
@@ -3992,7 +3996,7 @@ struct ggml_tensor * ggml_rope_inplace(
3992
3996
int n_dims ,
3993
3997
int mode ) {
3994
3998
return ggml_rope_impl (
3995
- ctx , a , b , NULL , n_dims , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , true
3999
+ ctx , a , b , NULL , n_dims , NULL , mode , 0 , 10000.0f , 1.0f , 0.0f , 1.0f , 0.0f , 0.0f , true
3996
4000
);
3997
4001
}
3998
4002
@@ -4011,7 +4015,7 @@ struct ggml_tensor * ggml_rope_ext(
4011
4015
float beta_fast ,
4012
4016
float beta_slow ) {
4013
4017
return ggml_rope_impl (
4014
- ctx , a , b , c , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4018
+ ctx , a , b , c , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
4015
4019
ext_factor , attn_factor , beta_fast , beta_slow , false
4016
4020
);
4017
4021
}
@@ -4031,7 +4035,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
4031
4035
float beta_fast ,
4032
4036
float beta_slow ) {
4033
4037
return ggml_rope_impl (
4034
- ctx , a , b , c , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4038
+ ctx , a , b , c , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
4035
4039
ext_factor , attn_factor , beta_fast , beta_slow , true
4036
4040
);
4037
4041
}
@@ -4050,7 +4054,7 @@ struct ggml_tensor * ggml_rope_custom(
4050
4054
float beta_fast ,
4051
4055
float beta_slow ) {
4052
4056
return ggml_rope_impl (
4053
- ctx , a , b , NULL , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4057
+ ctx , a , b , NULL , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
4054
4058
ext_factor , attn_factor , beta_fast , beta_slow , false
4055
4059
);
4056
4060
}
@@ -4069,7 +4073,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
4069
4073
float beta_fast ,
4070
4074
float beta_slow ) {
4071
4075
return ggml_rope_impl (
4072
- ctx , a , b , NULL , n_dims , mode , n_ctx_orig , freq_base , freq_scale ,
4076
+ ctx , a , b , NULL , n_dims , NULL , mode , n_ctx_orig , freq_base , freq_scale ,
4073
4077
ext_factor , attn_factor , beta_fast , beta_slow , true
4074
4078
);
4075
4079
}
0 commit comments