Skip to content

Commit a28df6f

Browse files
committed
sycl
1 parent 92a8738 commit a28df6f

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
16951695
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
16961696
}
16971697

1698-
static void scale_f32(const float * x, float * dst, const float scale, const int k,
1698+
static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
16991699
const sycl::nd_item<3> &item_ct1) {
17001700
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
17011701
item_ct1.get_local_id(2);
@@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
17041704
return;
17051705
}
17061706

1707-
dst[i] = scale * x[i];
1707+
dst[i] = scale * x[i] + bias;
17081708
}
17091709

17101710

@@ -1842,15 +1842,15 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
18421842

18431843

18441844

1845-
static void scale_f32_sycl(const float *x, float *dst, const float scale,
1845+
static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
18461846
const int k, queue_ptr stream) {
18471847
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
18481848
stream->parallel_for(
18491849
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
18501850
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
18511851
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
18521852
[=](sycl::nd_item<3> item_ct1) {
1853-
scale_f32(x, dst, scale, k, item_ct1);
1853+
scale_f32(x, dst, scale, bias, k, item_ct1);
18541854
});
18551855
}
18561856

@@ -2318,10 +2318,10 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
23182318
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
23192319
float * dst_dd = static_cast<float *>(dst->data);
23202320

2321-
float scale;
2322-
memcpy(&scale, dst->op_params, sizeof(float));
2321+
float scale = ((const float *)(dst->op_params))[0];
2322+
float bias = ((const float *)(dst->op_params))[1];
23232323

2324-
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
2324+
scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
23252325
/*
23262326
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
23272327
error codes. The call was replaced with 0. You need to rewrite this code.

0 commit comments

Comments
 (0)