@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
 #pragma unroll
     for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        int i0 = i + offset;
+        const int i0 = i + offset;
         if (i0 < n * n) {
             sA[i0] = A_batch[i0];
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    __syncthreads();
+    const int half = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
-
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        if (lane < row) {
+            sum += sA[row * n + lane] * x_low;
         }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row) {
+            x_low = (x_low - sum) / sA[row * n + row];
         }
+    }
 
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = sA[row * n + lane] * x_low;
+        const int j = half + lane;
+        if (j < row) {
+            sum += sA[row * n + j] * x_high;
+        }
         sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val = sXt[col_idx * n + row];
-            const float a_diag = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        if (lane == row - half) {
+            x_high = (x_high - sum) / sA[row * n + row];
         }
     }
 
-    __syncthreads();
-
 #pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+    for (int rr = 0; rr < 2; ++rr) {
+        const int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            const float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
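
Note on the approach: after this change each warp keeps its column of X entirely in registers. Lane l of the warp owns solution element l in x_low and element WARP_SIZE + l in x_high, so the per-row dot products no longer need the sXt tile in shared memory, nor any __syncthreads() between rows of the solve. This relies on warp_reduce_sum leaving the reduced value in every lane, which is why the lane == row check can pick up the result without a separate broadcast. A minimal sketch of such a helper, assuming the usual XOR-butterfly reduction (an assumption for illustration, not code from this PR):

// Sketch of an assumed warp_reduce_sum: XOR-butterfly reduction across
// the 32 lanes of a warp. After the loop every lane holds the total,
// so any single lane (here the one with lane == row) can consume it.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}

Because later rows read lane l's contribution through the product sA[row * n + lane] * x_low, the updated element already sits in the lane that needs to supply it.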
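
For reference, per batch the kernel computes a column-wise forward substitution, solving A * X = B for a lower-triangular n x n matrix A against an n x k right-hand side B. A hedged host-side sketch that mirrors the kernel's row-major indexing (solve_tri_reference is a hypothetical helper for illustration only):

// Hypothetical scalar reference for the computation above:
// A is n x n lower-triangular (row-major), B and X are n x k (row-major).
static void solve_tri_reference(const float * A, const float * B, float * X,
                                const int n, const int k) {
    for (int col = 0; col < k; ++col) {
        for (int row = 0; row < n; ++row) {
            float sum = 0.0f;
            for (int j = 0; j < row; ++j) {
                sum += A[row * n + j] * X[j * k + col];
            }
            // as in the kernel, no safeguard against a zero diagonal:
            // a singular triangular factor indicates corrupt data anyway
            X[row * k + col] = (B[row * k + col] - sum) / A[row * n + row];
        }
    }
}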