@@ -1415,15 +1415,35 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
 inline static void ggml_vec_set_f16 (const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
 inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
 inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
 inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i]));
+    }
+}
 
 static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
     assert(nrc == 1);
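All four new f16 helpers share one pattern: each lane is widened to f32, combined there, and rounded back to f16, so no intermediate arithmetic ever happens in half precision. Below is a minimal standalone sketch of that round-trip, using _Float16 as a stand-in for ggml_fp16_t and plain casts in place of the GGML_FP16_TO_FP32/GGML_FP32_TO_FP16 macros (an assumption for illustration; the real macros may use lookup tables or intrinsics depending on the platform):

// fp16_roundtrip.c -- the widen/compute/narrow pattern behind ggml_vec_add_f16
// assumes a compiler with _Float16 support (clang or gcc on recent targets)
#include <stdio.h>

typedef _Float16 fp16_t; // stand-in for ggml_fp16_t

static void vec_add_fp16(const int n, fp16_t * z, const fp16_t * x, const fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        // widen to f32, add, narrow back -- one rounding step per element, at the end
        z[i] = (fp16_t)((float) x[i] + (float) y[i]);
    }
}

int main(void) {
    fp16_t x[3] = {1.0f, 0.1f, 1024.0f};
    fp16_t y[3] = {2.0f, 0.2f,    1.0f};
    fp16_t z[3];
    vec_add_fp16(3, z, x, y);
    for (int i = 0; i < 3; ++i) {
        printf("%g\n", (double)(float) z[i]); // 3, ~0.3 (f16-rounded), 1025
    }
    return 0;
}

For a single add the result matches a correctly rounded native f16 add, since f32 represents the exact sum of any two f16 values; routing through the conversion macros simply avoids depending on hardware f16 arithmetic.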
@@ -4379,7 +4399,7 @@ static void ggml_compute_forward_add_f16_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4404,17 +4424,22 @@ static void ggml_compute_forward_add_f16_f16(
 
     if (nb10 == sizeof(ggml_fp16_t)) {
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
             }
         }
     }
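The reworked loop no longer assumes src1 matches src0's shape: the flat row index ir is decomposed into (i01, i02, i03), those are wrapped into src1's extents with a modulo (the broadcast), and the ne10-wide src1 row is applied nr0 = ne00/ne10 times across each src0 row. A small self-contained sketch of just that index arithmetic, with made-up extents (the ne* names follow the diff; the shapes are illustrative):

// broadcast_idx.c -- the row-index decomposition used by the f16 binary kernels
#include <stdio.h>
#include <stdint.h>

int main(void) {
    // src0 is 4 x 2 x 3 (ne00 x ne01 x ne02), src1 is 2 x 1 x 3:
    // src1 repeats twice along dim 0 and twice along dim 1
    const int64_t ne00 = 4, ne01 = 2, ne02 = 3;
    const int64_t ne10 = 2, ne11 = 1, ne12 = 3;

    const int64_t nr0 = ne00 / ne10; // 2: one src1 row tiles each src0 row twice
    const int64_t nr  = ne01*ne02;   // number of src0 rows (ne03 == 1 here)

    for (int64_t ir = 0; ir < nr; ++ir) {
        // same decomposition as the kernel (i03 omitted since ne03 == 1)
        const int64_t i02 = ir/ne01;
        const int64_t i01 = ir - i02*ne01;

        // wrap into src1's extents -- this is the broadcast
        const int64_t i12 = i02 % ne12;
        const int64_t i11 = i01 % ne11;

        printf("src0 row (i01=%lld, i02=%lld) uses src1 row (i11=%lld, i12=%lld), tiled %lld times\n",
               (long long) i01, (long long) i02, (long long) i11, (long long) i12, (long long) nr0);
    }
    return 0;
}

Because ggml_can_repeat(src1, src0) guarantees each src0 extent is a whole multiple of the corresponding src1 extent, the modulo wrap and the nr0 division are always exact.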
@@ -5202,6 +5227,62 @@ static void ggml_compute_forward_sub_f32(
     }
 }
 
+static void ggml_compute_forward_sub_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_sub(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5213,6 +5294,10 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sub_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
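With the F16 case wired into the dispatcher, half-precision tensors can be subtracted without a round trip through f32. A hedged sketch of how this might be exercised through the usual ggml context/graph API (buffer size and thread count are arbitrary; depending on the ggml revision, ggml_graph_compute_with_ctx may be declared in ggml.h or ggml-cpu.h, and ggml_sub must accept broadcastable shapes, which the ggml_can_repeat assert in the kernel above relies on):

// sub_f16_demo.c -- hypothetical driver for the new f16 sub path
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024, // arbitrary scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // 8 columns x 4 rows, minus an 8 x 1 row vector broadcast over the rows
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8, 4);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8, 1);
    ggml_set_f32(a, 3.0f); // ggml_set_f32 also fills f16 tensors, converting per element
    ggml_set_f32(b, 1.0f);

    struct ggml_tensor * c = ggml_sub(ctx, a, b); // stays F16: dispatches to ggml_compute_forward_sub_f16

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 4);

    printf("c[0] = %f\n", ggml_get_f32_1d(c, 0)); // expect 2.0

    ggml_free(ctx);
    return 0;
}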
@@ -5293,20 +5378,73 @@ static void ggml_compute_forward_mul_f32(
     }
 }
 
+static void ggml_compute_forward_mul_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_mul(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
+    GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 supported for now");
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_mul_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_mul_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
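Worth noting: mul_f16 (and div_f16 below) split work across threads differently from add/sub. Add and sub give each thread one contiguous block of dr = (nr + nth - 1)/nth rows, while mul and div stride through the rows, with thread ith taking rows ith, ith + nth, ith + 2*nth, and so on, mirroring their f32 counterparts. Both schemes cover every row exactly once; a tiny sketch contrasting them (plain index arithmetic, no ggml types):

// row_partition.c -- blocked (add/sub) vs strided (mul/div) row assignment
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr = 10, nth = 3; // 10 rows, 3 threads

    for (int ith = 0; ith < nth; ++ith) {
        // blocked scheme (add/sub): one contiguous chunk per thread
        const int dr  = (nr + nth - 1)/nth;
        const int ir0 = dr*ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d  blocked:", ith);
        for (int ir = ir0; ir < ir1; ++ir) printf(" %d", ir);

        // strided scheme (mul/div): every nth row
        printf("  | strided:");
        for (int ir = ith; ir < nr; ir += nth) printf(" %d", ir);
        printf("\n");
    }
    return 0;
}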
@@ -5387,6 +5525,55 @@ static void ggml_compute_forward_div_f32(
     }
 }
 
+static void ggml_compute_forward_div_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_div(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5398,6 +5585,10 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_div_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");