Skip to content

Commit d73ba84

Browse files
committed
Use a concise expression to further speed up the CUDA kernel
1 parent b3cf133 commit d73ba84

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

ggml/src/ggml-cuda/pad_reflect_1d.cu

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -36,20 +36,20 @@ static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
3636
const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
3737
char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
3838

39-
float value;
40-
const int64_t j = i0 - p0;
39+
const int64_t rel_i0 = i0 - p0;
40+
int64_t src_idx;
4141

42-
if (j < 0) { // i0<p0
42+
if (rel_i0 < 0) {
4343
// Left padding - reflect
44-
value = *(const float *) (src0_ptr - j * nb00);
45-
} else if (j < ne00) { //i0 < ne0 - p1
44+
src_idx = -rel_i0;
45+
} else if (rel_i0 < ne00) {
4646
// Middle - copy
47-
value = *(const float *) (src0_ptr + j * nb00);
47+
src_idx = rel_i0;
4848
} else {
4949
// Right padding - reflect
50-
const int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
51-
value = *(const float *) (src0_ptr + src_idx * nb00);
50+
src_idx = 2 * ne00 - 2 - rel_i0;
5251
}
52+
const float value = *(const float *) (src0_ptr + src_idx * nb00);
5353
*(float *) (dst_ptr + i0 * nb0) = value;
5454
}
5555

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6636,8 +6636,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
66366636

66376637
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
66386638
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
6639-
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
66406639
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
6640+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
66416641
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
66426642

66436643
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));

0 commit comments

Comments
 (0)