@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
 
 inline static void ggml_vec_set_i8 (const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32(const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
     }
 }
 
+static void ggml_compute_forward_set_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_i32(nc,
+                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
 static void ggml_compute_forward_set(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
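
The I32 kernel reuses the same multi-threaded row partitioning as the existing F32 path: the nr rows of src1 are split into chunks of dr rows per thread, and each flat row index ir is unfolded into (i1, i2, i3) coordinates. The following is a minimal standalone sketch of just that partitioning logic; the dimension values and thread count are made up for illustration and the snippet is not part of the patch.

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Standalone illustration of the row split used above.
// nr rows are distributed over nth threads; each flat row index ir is
// unfolded into (i1, i2, i3) exactly as in ggml_compute_forward_set_i32.
int main(void) {
    const int ne11 = 4, ne12 = 3, ne13 = 2;   // hypothetical src1 dimensions
    const int nr   = ne11*ne12*ne13;          // total rows (what ggml_nrows(src1) would return)
    const int nth  = 4;                       // number of threads

    for (int ith = 0; ith < nth; ++ith) {
        const int dr  = (nr + nth - 1)/nth;   // rows per thread (ceiling division)
        const int ir0 = dr*ith;               // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr);    // one past the last row

        for (int ir = ir0; ir < ir1; ++ir) {
            const int i3 = ir/(ne12*ne11);
            const int i2 = (ir - i3*ne12*ne11)/ne11;
            const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
            printf("thread %d handles row %d -> (i1=%d, i2=%d, i3=%d)\n", ith, ir, i1, i2, i3);
        }
    }
    return 0;
}
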
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
             {
                 ggml_compute_forward_set_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_set_i32(params, dst);
+            } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
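
To exercise the new dispatch path end to end, a sketch like the one below should work. It assumes only the public ggml API (ggml_init, ggml_new_tensor_1d, ggml_set, ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx); on trees with the split CPU backend, ggml_graph_compute_with_ctx comes from ggml-cpu.h instead of ggml.h. It is an illustration, not part of this change: it writes a 3-element I32 tensor into a 10-element I32 tensor starting at element 2, which routes through ggml_compute_forward_set_i32.

// sketch: set an I32 sub-range via GGML_OP_SET (assumed API usage, not part of the patch)
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 10);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);

    for (int i = 0; i < 10; ++i) ((int32_t *) a->data)[i] = 0;
    for (int i = 0; i <  3; ++i) ((int32_t *) b->data)[i] = i + 1;

    // view a with its own strides; start writing b at element index 2 (byte offset 2*sizeof(int32_t))
    struct ggml_tensor * c = ggml_set(ctx, a, b,
            a->nb[1], a->nb[2], a->nb[3], 2*sizeof(int32_t));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 10; ++i) {
        printf("%d ", ((int32_t *) c->data)[i]);   // expected: 0 0 1 2 3 0 0 0 0 0
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}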