Skip to content

Commit 084d650

Browse files
authored
Merge pull request #2 from am17an/solve_tri_cuda_opt
optimize
2 parents 4836963 + 42d6d58 commit 084d650

File tree

1 file changed: +34 -24 lines changed

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 34 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ static __global__ void solve_tri_f32_fast(
1313
const float* __restrict__ A,
1414
const float* __restrict__ B,
1515
float* __restrict__ X,
16-
const int64_t ne02,
16+
const uint3 ne02,
1717
const size_t nb02, const size_t nb03,
1818
const size_t nb12, const size_t nb13,
1919
const size_t nb2, const size_t nb3) {
@@ -26,8 +26,9 @@ static __global__ void solve_tri_f32_fast(
2626
return;
2727
}
2828

29-
const int64_t i03 = batch_idx / ne02;
30-
const int64_t i02 = batch_idx % ne02;
29+
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
30+
const int64_t i02 = i02_i03.y;
31+
const int64_t i03 = i02_i03.x;
3132

3233
const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
3334
const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
@@ -37,14 +38,19 @@ static __global__ void solve_tri_f32_fast(
3738
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
3839
__shared__ float sX[MAX_N_FAST * MAX_K_FAST];
3940

41+
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
4042
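// Load A into shared memory (coalesced). NOTE(review): the new loop below does not check i0 against n * n; when n * n is not a multiple of k * WARP_SIZE (e.g. the <64, 14> instantiation) the last iteration indexes past n * n - confirm and restore an `if (i0 < n * n)` guard.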
41-
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * n; i += blockDim.x * blockDim.y) {
42-
sA[i] = A_batch[i];
43+
#pragma unroll
44+
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
45+
int i0 = i + offset;
46+
sA[i0] = A_batch[i0];
4347
}
4448

4549
// Load B into shared memory (coalesced)
46-
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * k; i += blockDim.x * blockDim.y) {
47-
sX[i] = B_batch[i];
50+
#pragma unroll
51+
for (int i = 0; i < n * k; i += k * WARP_SIZE) {
52+
int i0 = i + threadIdx.x + threadIdx.y * blockDim.x;
53+
sX[i0] = B_batch[i0];
4854
}
4955
__syncthreads();
5056

@@ -74,16 +80,18 @@ static __global__ void solve_tri_f32_fast(
7480
}
7581

7682
// Write results from shared memory to global memory (coalesced)
77-
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * k; i += blockDim.x * blockDim.y) {
78-
X_batch[i] = sX[i];
83+
#pragma unroll
84+
for (int i = 0; i < n * k; i += k * WARP_SIZE) {
85+
const int i0 = i + threadIdx.x + threadIdx.y*blockDim.x;
86+
X_batch[i0] = sX[i0];
7987
}
8088
}
8189

8290
static __global__ void solve_tri_f32_fast_general(
8391
const float* __restrict__ A,
8492
const float* __restrict__ B,
8593
float* __restrict__ X,
86-
const int64_t ne02,
94+
const uint3 ne02,
8795
const size_t nb02, const size_t nb03,
8896
const size_t nb12, const size_t nb13,
8997
const size_t nb2, const size_t nb3,
@@ -97,8 +105,9 @@ static __global__ void solve_tri_f32_fast_general(
97105
return;
98106
}
99107

100-
const int64_t i03 = batch_idx / ne02;
101-
const int64_t i02 = batch_idx % ne02;
108+
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
109+
const int64_t i02 = i02_i03.y;
110+
const int64_t i03 = i02_i03.x;
102111

103112
const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
104113
const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
@@ -164,44 +173,45 @@ static void solve_tri_f32_cuda(
164173
cudaStream_t stream)
165174
{
166175
// n <= 64, k <= 32
176+
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
167177
dim3 threads(WARP_SIZE, k);
168178
dim3 grid(ne02 * ne03);
169179
if (n == 64) {
170180
if (k == 32) {
171181
solve_tri_f32_fast<64, 32><<<grid, threads, 0, stream>>>(
172-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
182+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
173183
} else if (k == 16) {
174184
solve_tri_f32_fast<64, 16><<<grid, threads, 0, stream>>>(
175-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
185+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
176186
} else if (k == 14) {
177187
solve_tri_f32_fast<64, 14><<<grid, threads, 0, stream>>>(
178-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
188+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
179189
} else if (k == 12) {
180190
solve_tri_f32_fast<64, 12><<<grid, threads, 0, stream>>>(
181-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
191+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
182192
} else if (k == 10) {
183193
solve_tri_f32_fast<64, 10><<<grid, threads, 0, stream>>>(
184-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
194+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
185195
} else if (k == 8) {
186196
solve_tri_f32_fast<64, 8><<<grid, threads, 0, stream>>>(
187-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
197+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
188198
} else if (k == 6) {
189199
solve_tri_f32_fast<64, 6><<<grid, threads, 0, stream>>>(
190-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
200+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
191201
} else if (k == 4) {
192202
solve_tri_f32_fast<64, 4><<<grid, threads, 0, stream>>>(
193-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
203+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
194204
} else if (k == 2) {
195205
solve_tri_f32_fast<64, 2><<<grid, threads, 0, stream>>>(
196-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
206+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
197207
} else if (k == 1) {
198208
solve_tri_f32_fast<64, 1><<<grid, threads, 0, stream>>>(
199-
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
209+
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
200210
} else {
201-
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
211+
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
202212
}
203213
} else { // run general case
204-
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
214+
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
205215
}
206216
}
207217

0 commit comments

Comments
 (0)