Skip to content

Commit 0ed09a1

Browse files
committed
sycl: add PAD_REFLECT_D1 operator support
1 parent 1eeb523 commit 0ed09a1

File tree

5 files changed

+95
-1
lines changed

5 files changed

+95
-1
lines changed

docs/ops/SYCL.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6025,7 +6025,7 @@
60256025
"SYCL0","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","SYCL"
60266026
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","SYCL"
60276027
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","SYCL"
6028-
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","SYCL"
6028+
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
60296029
"SYCL0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","SYCL"
60306030
"SYCL0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","SYCL"
60316031
"SYCL0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","SYCL"

ggml/src/ggml-sycl/backend.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,7 @@
3535
#include "softmax.hpp"
3636
#include "tsembd.hpp"
3737
#include "wkv.hpp"
38+
#include "pad_reflect_d1.hpp"
39+
3840

3941
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3673,6 +3673,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36733673
case GGML_OP_CONCAT:
36743674
ggml_sycl_op_concat(ctx, dst);
36753675
break;
3676+
case GGML_OP_PAD_REFLECT_1D:
3677+
ggml_sycl_op_pad_reflect_d1(ctx,dst);
3678+
break;
36763679
case GGML_OP_UPSCALE:
36773680
ggml_sycl_upscale(ctx, dst);
36783681
break;
@@ -4369,6 +4372,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
43694372
case GGML_OP_SIN:
43704373
case GGML_OP_COS:
43714374
case GGML_OP_CLAMP:
4375+
case GGML_OP_PAD_REFLECT_1D:
4376+
return ggml_is_contiguous(op->src[0]) &&
4377+
op-> type == GGML_TYPE_F32 &&
4378+
op->src[0]->type == GGML_TYPE_F32;
43724379
case GGML_OP_LOG:
43734380
#if defined (GGML_SYCL_F16)
43744381
return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "pad_reflect_d1.hpp"
2+
3+
void pad_reflect_d1_f32(const float* src,float* dst,
4+
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
5+
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
6+
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
7+
const sycl::nd_item<3> &item_ct1){
8+
9+
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
10+
const int i1 = item_ct1.get_group(1);
11+
const int g2 = item_ct1.get_group(2);
12+
const int i2 = g2 % ne02;
13+
const int i3 = g2 / ne02;
14+
15+
if (i0 >= p0 + ne0 + p1) return;
16+
17+
int t = i0 - p0;
18+
int period = 2 * ne0 -2;
19+
int m = t % period;
20+
m += (m < 0) * period;
21+
int center = ne0 -1;
22+
int srci0 = center - abs(center - m);
23+
24+
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
25+
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
26+
dst[offest_dst] = src[offest_src];
27+
28+
}
29+
30+
void ggml_sycl_op_pad_reflect_d1(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
31+
32+
const ggml_tensor * src0 = dst->src[0];
33+
queue_ptr stream = ctx.stream();
34+
35+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
36+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
37+
38+
const int32_t * opts = (const int32_t *) dst->op_params;
39+
const int p0 = opts[0];
40+
const int p1 = opts[1];
41+
42+
const int64_t ne0 = src0->ne[0];
43+
44+
const int64_t ne00 = dst->ne[0];
45+
const int64_t ne01 = dst->ne[1];
46+
const int64_t ne02 = dst->ne[2];
47+
const int64_t ne03 = dst->ne[3];
48+
49+
const int64_t nb00 = dst->nb[0];
50+
const int64_t nb01 = dst->nb[1];
51+
const int64_t nb02 = dst->nb[2];
52+
const int64_t nb03 = dst->nb[3];
53+
const int64_t nb0 = src0->nb[0];
54+
const int64_t nb1 = src0->nb[1];
55+
const int64_t nb2 = src0->nb[2];
56+
const int64_t nb3 = src0->nb[3];
57+
58+
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
59+
60+
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
61+
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
62+
63+
stream->parallel_for(
64+
sycl::nd_range<3>(global,
65+
local),
66+
[=](sycl::nd_item<3> item_ct1) { pad_reflect_d1_f32(
67+
(const float *) src0->data, (float *) dst->data,
68+
ne0, ne02, p0, p1,
69+
nb0, nb1, nb2, nb3,
70+
nb00, nb01, nb02, nb03
71+
, item_ct1);
72+
});
73+
}
74+
75+
76+
77+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_PAD_REFLECT_D1_HPP
2+
#define GGML_SYCL_PAD_REFLECT_D1_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_op_pad_reflect_d1(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
7+
8+
#endif // GGML_SYCL_PAD_REFLECT_D1_HPP

0 commit comments

Comments
 (0)