@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
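Note: the backward kernel now takes the incoming gradient as src[0] and the forward input as src[1]; the math is unchanged. For reference, silu(x) = x*sigmoid(x), so dsilu/dx = s*(1 + x*(1 - s)) with s = sigmoid(x). A minimal plain-C sketch of what a per-element SiLU backward computes (the helper name and loop body here are an assumption based on that derivative, not code from this commit):

    #include <math.h>

    // dx[i] = dy[i] * dsilu/dx evaluated at x[i], with s = sigmoid(x[i])
    static void vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
        for (int i = 0; i < n; ++i) {
            const float s = 1.0f/(1.0f + expf(-x[i]));
            dx[i] = dy[i]*s*(1.0f + x[i]*(1.0f - s));
        }
    }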
@@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
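Note: both assertions (norm and rms_norm) are relaxed from eps > 0.0f to eps >= 0.0f, so an exact, unregularized normalization with eps == 0 is now accepted. In LaTeX, rms_norm computes

    y_i = \frac{x_i}{\sqrt{\tfrac{1}{N}\sum_{j=1}^{N} x_j^2 + \varepsilon}}, \qquad \varepsilon \ge 0,

which remains well-defined at eps = 0 for any row that is not identically zero.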
@@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float * x  = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,23 +7067,23 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(src0) =
+                    // rms_norm(src1) =
                     //     scale(
-                    //         src0,
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
                     //                 add(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //                                 src0)),
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op    args         grad
-                    // 00 param src0         grad[#00]
+                    // 00 param src1         grad[#00]
                     // 01 const 1
                     // 02 sqr   (#00)        grad[#02]
                     // 03 sum   (#02)        grad[#03]
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
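Note: the new comment states the closed form that the ggml_vec_* calls implement. With N = ne00, sum_eps = \sum_i x_i^2 + N\varepsilon and mean_eps = sum_eps/N, the per-row gradient is

    dx = \frac{dz \;-\; x \cdot \dfrac{\sum_i x_i\, dz_i}{\sum_i x_i^2 + N\varepsilon}}{\sqrt{\tfrac{1}{N}\sum_i x_i^2 + \varepsilon}},

i.e. copy x, scale by -sum_xdz/sum_eps, add dz, then scale by rrms = 1/\sqrt{mean_eps}.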
@@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
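Note: dps2/dps3 implement the broadcast needed for grouped-query attention, where several dst channels share one src0 channel. A standalone sketch of the mapping, with hypothetical head counts (not values from this commit):

    #include <assert.h>
    #include <stdio.h>

    int main(void) {
        const int ne2  = 32;          // dst channels, e.g. query heads
        const int ne02 = 4;           // src0 channels, e.g. K/V heads
        assert(ne2 % ne02 == 0);      // same divisibility the new GGML_ASSERTs enforce
        const int dps2 = ne2 / ne02;  // "dst per src0" == 8
        for (int i2 = 0; i2 < ne2; ++i2) {
            const int i02 = i2 / dps2; // 0..7 -> 0, 8..15 -> 1, 16..23 -> 2, 24..31 -> 3
            printf("i2=%2d -> i02=%d\n", i2, i02);
        }
        return 0;
    }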
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
 }
 
 
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
 
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
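Note: the scale/max_bias pair mirrors the op_params layout of the forward GGML_OP_SOFT_MAX: two floats packed at the start of dst->op_params. The assertion documents that the ALiBi bias (max_bias != 0) is not yet supported by this backward pass. A sketch of the matching write on the graph-building side, modeled on ggml_soft_max_ext (an assumption, not code from this commit):

    // pack the two op parameters next to each other in dst->op_params
    float params[2] = { scale, max_bias };
    ggml_set_op_params(result, params, sizeof(params));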
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
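Note: with y = softmax(scale*x) saved from the forward pass, the chain of vector ops above is the softmax Jacobian-vector product, and the added ggml_vec_scale_f32 applies the chain-rule factor for the scale:

    dx = scale \cdot \bigl(\operatorname{diag}(y) - y\,y^{\top}\bigr)\, dy
       = scale \cdot y \odot \bigl(dy - (y^{\top} dy)\,\mathbf{1}\bigr).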
@@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back(
     }
 }
 
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                 continue;
                             }
 
-                            const float * const src_data = (const float *) src1->data
+                            const float * const grad_in = (const float *) src0->data
                                 + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                            grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                            grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                         }
                     }
                     float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
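Note: the index fix follows from the changed tensor roles: KH/KW now come from src1 (the convolution kernel) and OH/OW from src0 (the incoming gradient), matching the comments added at the top of the function. The loop accumulates, for each input pixel, the gradient from every output pixel that read it; based on the standard im2col index relation (an inference from the surrounding code, with stride s, padding p, dilation d):

    i_{ow} = \frac{i_{iw} + p_0 - i_{kw}\, d_0}{s_0}, \qquad
    i_{oh} = \frac{i_{ih} + p_1 - i_{kh}\, d_1}{s_1},

and kernel taps for which the division is not exact or the result is out of range are skipped (the `continue` above).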
@@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float       * ds0 = (float       *)((      char *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
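Note: ds0 is built in place: a numerically stable softmax of s0, then subtract s1 and scale by d_by_nr, exactly the commented formula. A plain-C sketch of the stable softmax that ggml_vec_max_f32/ggml_vec_soft_max_f32 provide (an assumption about the helpers' semantics, not their actual implementation):

    #include <math.h>

    // y[i] = exp(x[i] - max); returns the sum so the caller can normalize
    static float vec_soft_max_f32(const int n, float * y, const float * x, const float max) {
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            y[i] = expf(x[i] - max);
            sum += y[i];
        }
        return sum;
    }

Subtracting the row maximum before exponentiating avoids overflow without changing the normalized result.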
@@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-                ggml_compute_forward_soft_max_back(params, tensor);
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {