@@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
45304530 GGML_TENSOR_BINARY_OP_LOCALS
45314531
45324532 const int64_t nc = ne00;
4533- const int64_t nr = ggml_nelements (src1) ;
4533+ const int64_t nr = ne01 ;
45344534
45354535 assert (ne0 == nc);
4536- assert (ne02 == ne11);
4537- assert (nb00 == sizeof (float ));
4538- assert (ggml_nrows (src0) == nr);
4536+ assert (ne2 == ne02);
4537+ assert (ne3 == ne03);
4538+ assert (src0->type == GGML_TYPE_F32);
4539+ assert (ne02 % ne11 == 0 );
4540+ assert (ne03 % ne12 == 0 );
45394541
45404542 const int ith = params->ith ;
45414543 const int nth = params->nth ;
@@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32(
45474549 const int ir0 = dr*ith;
45484550 const int ir1 = MIN (ir0 + dr, nr);
45494551
4550- for (int64_t i = ir0; i < ir1; ++i) {
4551- const int64_t i12 = i/(ne11*ne10);
4552- const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4553- const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4554- const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4552+ for (int64_t i03 = 0 ; i03 < ne03; ++i03) {
4553+ for (int64_t i02 = 0 ; i02 < ne02; ++i02) {
4554+ for (int64_t i = ir0; i < ir1; ++i) {
4555+ const int64_t i12 = i03%ne12;
4556+ const int64_t i11 = i02%ne11;
4557+ const int64_t i10 = i;
4558+
4559+ const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
45554560
4556- GGML_ASSERT (i01 >= 0 && i01 < ne1);
4561+ GGML_ASSERT (i01 >= 0 && i01 < ne1);
45574562
4558- ggml_cpu_fp32_to_fp16 (
4559- (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
4560- (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
4563+ ggml_cpu_fp32_to_fp16 (
4564+ (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
4565+ (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), nc);
4566+ }
4567+ }
45614568 }
45624569}
45634570
0 commit comments