@@ -2858,12 +2858,14 @@ static struct ggml_tensor * ggml_scale_impl(
28582858 struct ggml_context * ctx ,
28592859 struct ggml_tensor * a ,
28602860 float s ,
2861+ float b ,
28612862 bool inplace ) {
28622863 GGML_ASSERT (ggml_is_padded_1d (a ));
28632864
28642865 struct ggml_tensor * result = inplace ? ggml_view_tensor (ctx , a ) : ggml_dup_tensor (ctx , a );
28652866
2866- ggml_set_op_params (result , & s , sizeof (s ));
2867+ float params [2 ] = { s , b };
2868+ ggml_set_op_params (result , & params , sizeof (params ));
28672869
28682870 result -> op = GGML_OP_SCALE ;
28692871 result -> src [0 ] = a ;
@@ -2875,14 +2877,30 @@ struct ggml_tensor * ggml_scale(
28752877 struct ggml_context * ctx ,
28762878 struct ggml_tensor * a ,
28772879 float s ) {
2878- return ggml_scale_impl (ctx , a , s , false);
2880+ return ggml_scale_impl (ctx , a , s , 0.0 , false);
28792881}
28802882
28812883struct ggml_tensor * ggml_scale_inplace (
28822884 struct ggml_context * ctx ,
28832885 struct ggml_tensor * a ,
28842886 float s ) {
2885- return ggml_scale_impl (ctx , a , s , true);
2887+ return ggml_scale_impl (ctx , a , s , 0.0 , true);
2888+ }
2889+
2890+ struct ggml_tensor * ggml_scale_bias (
2891+ struct ggml_context * ctx ,
2892+ struct ggml_tensor * a ,
2893+ float s ,
2894+ float b ) {
2895+ return ggml_scale_impl (ctx , a , s , b , false);
2896+ }
2897+
2898+ struct ggml_tensor * ggml_scale_bias_inplace (
2899+ struct ggml_context * ctx ,
2900+ struct ggml_tensor * a ,
2901+ float s ,
2902+ float b ) {
2903+ return ggml_scale_impl (ctx , a , s , b , true);
28862904}
28872905
28882906// ggml_set
@@ -5472,7 +5490,7 @@ static void ggml_compute_backward(
54725490 } break ;
54735491 case GGML_OP_MEAN : {
54745492 if (src0_needs_grads ) {
5475- ggml_add1_or_set (ctx , cgraph , isrc0 , ggml_scale_impl (ctx , grad , 1.0f /src0 -> ne [0 ], false));
5493+ ggml_add1_or_set (ctx , cgraph , isrc0 , ggml_scale_impl (ctx , grad , 1.0f /src0 -> ne [0 ], 0.0 , false));
54765494 }
54775495 } break ;
54785496 case GGML_OP_REPEAT : {
@@ -5549,7 +5567,7 @@ static void ggml_compute_backward(
55495567 if (src0_needs_grads ) {
55505568 float s ;
55515569 memcpy (& s , tensor -> op_params , sizeof (float ));
5552- ggml_add_or_set (ctx , cgraph , isrc0 , ggml_scale_impl (ctx , grad , s , false));
5570+ ggml_add_or_set (ctx , cgraph , isrc0 , ggml_scale_impl (ctx , grad , s , 0.0 , false));
55535571 }
55545572 } break ;
55555573 case GGML_OP_SET : {
0 commit comments