@@ -13,7 +13,7 @@ static __global__ void solve_tri_f32_fast(
1313 const float * __restrict__ A,
1414 const float * __restrict__ B,
1515 float * __restrict__ X,
16- const int64_t ne02,
16+ const uint3 ne02,
1717 const size_t nb02, const size_t nb03,
1818 const size_t nb12, const size_t nb13,
1919 const size_t nb2, const size_t nb3) {
@@ -26,8 +26,9 @@ static __global__ void solve_tri_f32_fast(
2626 return ;
2727 }
2828
29- const int64_t i03 = batch_idx / ne02;
30- const int64_t i02 = batch_idx % ne02;
29+ const uint2 i02_i03 = fast_div_modulo (batch_idx, ne02);
30+ const int64_t i02 = i02_i03.y ;
31+ const int64_t i03 = i02_i03.x ;
3132
3233 const float * const A_batch = (const float *)((const char *)A + i02 * nb02 + i03 * nb03);
3334 const float * const B_batch = (const float *)((const char *)B + i02 * nb12 + i03 * nb13);
@@ -37,14 +38,19 @@ static __global__ void solve_tri_f32_fast(
3738 __shared__ float sA [MAX_N_FAST * MAX_N_FAST];
3839 __shared__ float sX [MAX_N_FAST * MAX_K_FAST];
3940
41+ const int offset = threadIdx .x + threadIdx .y * blockDim .x ;
4042 // Load A into shared memory (coalesced)
41- for (int i = threadIdx .x + threadIdx .y * blockDim .x ; i < n * n; i += blockDim .x * blockDim .y ) {
42- sA [i] = A_batch[i];
43+ #pragma unroll
44+ for (int i = 0 ; i < n * n; i += k * WARP_SIZE) {
45+ int i0 = i + offset;
46+ sA [i0] = A_batch[i0];
4347 }
4448
4549 // Load B into shared memory (coalesced)
46- for (int i = threadIdx .x + threadIdx .y * blockDim .x ; i < n * k; i += blockDim .x * blockDim .y ) {
47- sX [i] = B_batch[i];
50+ #pragma unroll
51+ for (int i = 0 ; i < n * k; i += k * WARP_SIZE) {
52+ int i0 = i + threadIdx .x + threadIdx .y * blockDim .x ;
53+ sX [i0] = B_batch[i0];
4854 }
4955 __syncthreads ();
5056
@@ -74,16 +80,18 @@ static __global__ void solve_tri_f32_fast(
7480 }
7581
7682 // Write results from shared memory to global memory (coalesced)
77- for (int i = threadIdx .x + threadIdx .y * blockDim .x ; i < n * k; i += blockDim .x * blockDim .y ) {
78- X_batch[i] = sX [i];
83+ #pragma unroll
84+ for (int i = 0 ; i < n * k; i += k * WARP_SIZE) {
85+ const int i0 = i + threadIdx .x + threadIdx .y *blockDim .x ;
86+ X_batch[i0] = sX [i0];
7987 }
8088}
8189
8290static __global__ void solve_tri_f32_fast_general (
8391 const float * __restrict__ A,
8492 const float * __restrict__ B,
8593 float * __restrict__ X,
86- const int64_t ne02,
94+ const uint3 ne02,
8795 const size_t nb02, const size_t nb03,
8896 const size_t nb12, const size_t nb13,
8997 const size_t nb2, const size_t nb3,
@@ -97,8 +105,9 @@ static __global__ void solve_tri_f32_fast_general(
97105 return ;
98106 }
99107
100- const int64_t i03 = batch_idx / ne02;
101- const int64_t i02 = batch_idx % ne02;
108+ const uint2 i02_i03 = fast_div_modulo (batch_idx, ne02);
109+ const int64_t i02 = i02_i03.y ;
110+ const int64_t i03 = i02_i03.x ;
102111
103112 const float * const A_batch = (const float *)((const char *)A + i02 * nb02 + i03 * nb03);
104113 const float * const B_batch = (const float *)((const char *)B + i02 * nb12 + i03 * nb13);
@@ -164,44 +173,45 @@ static void solve_tri_f32_cuda(
164173 cudaStream_t stream)
165174{
166175 // n <= 64, k <= 32
176+ const uint3 ne02_fd = init_fastdiv_values ((uint32_t ) ne02);
167177 dim3 threads (WARP_SIZE, k);
168178 dim3 grid (ne02 * ne03);
169179 if (n == 64 ) {
170180 if (k == 32 ) {
171181 solve_tri_f32_fast<64 , 32 ><<<grid, threads, 0 , stream>>> (
172- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
182+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
173183 } else if (k == 16 ) {
174184 solve_tri_f32_fast<64 , 16 ><<<grid, threads, 0 , stream>>> (
175- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
185+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
176186 } else if (k == 14 ) {
177187 solve_tri_f32_fast<64 , 14 ><<<grid, threads, 0 , stream>>> (
178- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
188+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
179189 } else if (k == 12 ) {
180190 solve_tri_f32_fast<64 , 12 ><<<grid, threads, 0 , stream>>> (
181- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
191+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
182192 } else if (k == 10 ) {
183193 solve_tri_f32_fast<64 , 10 ><<<grid, threads, 0 , stream>>> (
184- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
194+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
185195 } else if (k == 8 ) {
186196 solve_tri_f32_fast<64 , 8 ><<<grid, threads, 0 , stream>>> (
187- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
197+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
188198 } else if (k == 6 ) {
189199 solve_tri_f32_fast<64 , 6 ><<<grid, threads, 0 , stream>>> (
190- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
200+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
191201 } else if (k == 4 ) {
192202 solve_tri_f32_fast<64 , 4 ><<<grid, threads, 0 , stream>>> (
193- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
203+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
194204 } else if (k == 2 ) {
195205 solve_tri_f32_fast<64 , 2 ><<<grid, threads, 0 , stream>>> (
196- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
206+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
197207 } else if (k == 1 ) {
198208 solve_tri_f32_fast<64 , 1 ><<<grid, threads, 0 , stream>>> (
199- A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3);
209+ A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3);
200210 } else {
201- solve_tri_f32_fast_general<<<grid, threads, 0 , stream>>> (A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3, n, k);
211+ solve_tri_f32_fast_general<<<grid, threads, 0 , stream>>> (A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3, n, k);
202212 }
203213 } else { // run general case
204- solve_tri_f32_fast_general<<<grid, threads, 0 , stream>>> (A, B, X, ne02 , nb02, nb03, nb12, nb13, nb2, nb3, n, k);
214+ solve_tri_f32_fast_general<<<grid, threads, 0 , stream>>> (A, B, X, ne02_fd , nb02, nb03, nb12, nb13, nb2, nb3, n, k);
205215 }
206216}
207217
0 commit comments