Skip to content

Commit d1b1757

Browse files
cpu: support non-contiguous inputs for get_rel_pos
1 parent 5d279bd commit d1b1757

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9346,14 +9346,15 @@ static void ggml_compute_forward_get_rel_pos_f32(
93469346
const float k_scale = MAX((float)qh / kh, 1.0f);
93479347
const float q_scale = MAX((float)kh / qh, 1.0f);
93489348

9349-
float * src0_data = (float *) src0->data;
9350-
float * dst_data = (float *) dst->data;
9349+
const char * src0_d = (const char *) src0->data;
9350+
char * dst_d = (char *) dst->data;
93519351

93529352
for (int64_t i2 = 0; i2 < ne2; ++i2) {
93539353
for (int64_t i1 = 0; i1 < ne1; ++i1) {
93549354
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
93559355
for (int64_t i0 = 0; i0 < ne0; ++i0) {
9356-
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9356+
const float val = *(const float *) (src0_d + pos*nb01 + i0*nb00);
9357+
*(float *) (dst_d + i2*nb2 + i1*nb1 + i0*nb0) = val;
93579358
}
93589359
}
93599360
}
@@ -9375,14 +9376,15 @@ static void ggml_compute_forward_get_rel_pos_f16(
93759376
const float k_scale = MAX((float)qh / kh, 1.0f);
93769377
const float q_scale = MAX((float)kh / qh, 1.0f);
93779378

9378-
ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
9379-
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
9379+
const char * src0_d = (const char *) src0->data;
9380+
char * dst_d = (char *) dst->data;
93809381

93819382
for (int64_t i2 = 0; i2 < ne2; ++i2) {
93829383
for (int64_t i1 = 0; i1 < ne1; ++i1) {
93839384
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
93849385
for (int64_t i0 = 0; i0 < ne0; ++i0) {
9385-
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9386+
const ggml_fp16_t val = *(const ggml_fp16_t *) (src0_d + pos*nb01 + i0*nb00);
9387+
*(ggml_fp16_t *) (dst_d + i2*nb2 + i1*nb1 + i0*nb0) = val;
93869388
}
93879389
}
93889390
}

0 commit comments

Comments
 (0)