Skip to content

Commit 1e29faf

Browse files
committed
CUDA: Optimize PAD_REFLECT_1D
feat: add more test cases for PAD_REFLECT_1D
1 parent 704d90c commit 1e29faf

File tree

2 files changed

+34
-20
lines changed

2 files changed

+34
-20
lines changed

ggml/src/ggml-cuda/pad_reflect_1d.cu

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,33 @@ static __global__ void pad_reflect_1d_kernel_f32(
2121

2222
const int64_t i3 = blockIdx.z;
2323
const int64_t i2 = blockIdx.y;
24-
const int64_t i1 = blockIdx.x;
2524

26-
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
25+
const int64_t tile1 = blockIdx.x % ne01; // i1
26+
const int64_t tile0 = blockIdx.x / ne01; // nth i0 tile
27+
const int64_t i1 = tile1;
28+
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
29+
if ( i0 >= ne0 || i1 >= ne01 || i2 >= ne02 || i3 >= ne03 ) {
2730
return;
2831
}
2932

3033
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
31-
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
34+
const char * dst_ptr = (const char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
3235

33-
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
34-
float value;
36+
float value;
37+
const int64_t j = i0 - p0;
3538

36-
if (i0 < p0) {
37-
// Left padding - reflect
38-
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
39-
} else if (i0 < ne0 - p1) {
40-
// Middle - copy
41-
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
42-
} else {
43-
// Right padding - reflect
44-
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
45-
value = *(const float *)(src0_ptr + src_idx * nb00);
46-
}
47-
48-
*(float *)(dst_ptr + i0 * nb0) = value;
39+
if ( j<0 ) {// i0<p0
40+
// Left padding - reflect
41+
value = *(const float *)(src0_ptr - j * nb00);
42+
} else if (j < ne00) { //i0 < ne0 - p1
43+
// Middle - copy
44+
value = *(const float *)(src0_ptr + j * nb00);
45+
} else {
46+
// Right padding - reflect
47+
const int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
48+
value = *(const float *)(src0_ptr + src_idx * nb00);
4949
}
50+
*(float *)(dst_ptr + i0 * nb0) = value;
5051
}
5152

5253
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -67,10 +68,16 @@ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor *
6768

6869
const int64_t ne0 = dst->ne[0];
6970

71+
// sanity: padded length matches
7072
GGML_ASSERT(ne0 == ne00 + p0 + p1);
7173

72-
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
73-
const dim3 grid_dims(ne01, ne02, ne03);
74+
constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
75+
const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
76+
// grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
77+
// grid.y covers i2: [ne02]
78+
// grid.z covers i3: [ne03]
79+
const dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, (unsigned)ne03);
80+
const dim3 block_dims((unsigned)bx, 1, 1);
7481

7582
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
7683
src0->data, dst->data,

tests/test-backend-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6495,6 +6495,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
64956495
test_cases.emplace_back(new test_pad());
64966496
test_cases.emplace_back(new test_pad_ext());
64976497
test_cases.emplace_back(new test_pad_reflect_1d());
6498+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
64986499
test_cases.emplace_back(new test_roll());
64996500
test_cases.emplace_back(new test_arange());
65006501
test_cases.emplace_back(new test_timestep_embedding());
@@ -6633,6 +6634,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
66336634
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
66346635
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
66356636

6637+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
6638+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
6639+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
6640+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
6641+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
6642+
66366643
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
66376644
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
66386645

0 commit comments

Comments
 (0)