@@ -2303,13 +2303,15 @@ struct ggml_tensor * ggml_repeat(
23032303struct ggml_tensor * ggml_repeat_back (
23042304 struct ggml_context * ctx ,
23052305 struct ggml_tensor * a ,
2306- struct ggml_tensor * b ) {
2306+ struct ggml_tensor * b ,
2307+ bool gqa_mode ) {
23072308 GGML_ASSERT (ggml_can_repeat (b , a ));
23082309
23092310 struct ggml_tensor * result = ggml_new_tensor (ctx , a -> type , GGML_MAX_DIMS , b -> ne );
23102311
23112312 result -> op = GGML_OP_REPEAT_BACK ;
23122313 result -> src [0 ] = a ;
2314+ result -> op_params [1 ] = gqa_mode ? 1 : 0 ;
23132315
23142316 return result ;
23152317}
@@ -5129,7 +5131,7 @@ static void ggml_compute_backward(
51295131 if (src1_needs_grads ) {
51305132 struct ggml_tensor * tmp = grad ;
51315133 if (!ggml_are_same_shape (src0 , src1 )) {
5132- tmp = ggml_repeat_back (ctx , tmp , src1 );
5134+ tmp = ggml_repeat_back (ctx , tmp , src1 , false );
51335135 }
51345136 ggml_add_or_set (ctx , cgraph , isrc1 , tmp );
51355137 }
@@ -5174,7 +5176,7 @@ static void ggml_compute_backward(
51745176 if (src1_needs_grads ) {
51755177 struct ggml_tensor * tmp = ggml_mul (ctx , src0 , grad );
51765178 if (!ggml_are_same_shape (src0 , src1 )) {
5177- tmp = ggml_repeat_back (ctx , tmp , src1 );
5179+ tmp = ggml_repeat_back (ctx , tmp , src1 , false );
51785180 }
51795181 ggml_add_or_set (ctx , cgraph , isrc1 , tmp );
51805182 }
@@ -5229,7 +5231,7 @@ static void ggml_compute_backward(
52295231 } break ;
52305232 case GGML_OP_REPEAT : {
52315233 if (src0_needs_grads ) {
5232- ggml_add_or_set (ctx , cgraph , isrc0 , ggml_repeat_back (ctx , grad , src0 ));
5234+ ggml_add_or_set (ctx , cgraph , isrc0 , ggml_repeat_back (ctx , grad , src0 , false ));
52335235 }
52345236 } break ;
52355237 case GGML_OP_REPEAT_BACK : {
@@ -5268,8 +5270,7 @@ static void ggml_compute_backward(
52685270 if (!ggml_are_same_shape (tmp , src0 )) {
52695271 GGML_ASSERT (tmp -> ne [0 ] == src0 -> ne [0 ]);
52705272 GGML_ASSERT (tmp -> ne [1 ] == src0 -> ne [1 ]);
5271- tmp = ggml_repeat_back (ctx , tmp , src0 );
5272- tmp -> op_params [0 ] = 1 ; // FIXME
5273+ tmp = ggml_repeat_back (ctx , tmp , src0 , true);
52735274 }
52745275 ggml_add_or_set (ctx , cgraph , isrc0 , tmp );
52755276 }
0 commit comments