
Commit 0eabb67

small refactor

1 parent 4853ac1 commit 0eabb67

1 file changed: +38 -92 lines changed

include/quadblas/algorithms/level3.hpp

Lines changed: 38 additions & 92 deletions
@@ -13,93 +13,79 @@
 namespace QuadBLAS
 {
 
-    // CORRECTED: Scalar micro-kernel for reliability
     inline void gemm_micro_kernel_scalar(size_t mr, size_t nr, size_t kc,
-                                         Sleef_quad alpha,
-                                         Sleef_quad *A_packed, Sleef_quad *B_packed,
-                                         Sleef_quad beta, Sleef_quad *C, size_t ldc)
+                                         Sleef_quad alpha,
+                                         Sleef_quad *A_packed, Sleef_quad *B_packed,
+                                         Sleef_quad beta, Sleef_quad *C, size_t ldc)
     {
-        // A_packed: row-major, mr x kc (A_packed[row * kc + col])
-        // B_packed: row-major, kc x nr (B_packed[row * nr + col])
-        // C: original matrix with leading dimension ldc
-
+
         for (size_t i = 0; i < mr; ++i)
         {
             for (size_t j = 0; j < nr; ++j)
             {
                 Sleef_quad sum = SLEEF_QUAD_C(0.0);
-
-                // Compute dot product: A_packed[i,:] • B_packed[:,j]
+
                 for (size_t k = 0; k < kc; ++k)
                 {
-                    Sleef_quad a_val = A_packed[i * kc + k]; // A[i][k]
-                    Sleef_quad b_val = B_packed[k * nr + j]; // B[k][j]
-                    sum = Sleef_fmaq1_u05(a_val, b_val, sum); // sum += a * b
+                    Sleef_quad a_val = A_packed[i * kc + k];
+                    Sleef_quad b_val = B_packed[k * nr + j];
+                    sum = Sleef_fmaq1_u05(a_val, b_val, sum);
                 }
-
-                // C[i][j] = alpha * sum + beta * C[i][j]
+
                 Sleef_quad c_old = C[i * ldc + j];
                 C[i * ldc + j] = Sleef_fmaq1_u05(alpha, sum, Sleef_mulq1_u05(beta, c_old));
             }
         }
     }
 
-    // CORRECTED: Vectorized micro-kernel (more complex but correct)
     inline void gemm_micro_kernel_vectorized(size_t mr, size_t nr, size_t kc,
-                                             Sleef_quad alpha,
-                                             Sleef_quad *A_packed, Sleef_quad *B_packed,
-                                             Sleef_quad beta, Sleef_quad *C, size_t ldc)
+                                             Sleef_quad alpha,
+                                             Sleef_quad *A_packed, Sleef_quad *B_packed,
+                                             Sleef_quad beta, Sleef_quad *C, size_t ldc)
     {
-        // Only use vectorization if dimensions are suitable
+
         if (mr % VECTOR_SIZE != 0 || nr % VECTOR_SIZE != 0 || mr < VECTOR_SIZE || nr < VECTOR_SIZE)
         {
             gemm_micro_kernel_scalar(mr, nr, kc, alpha, A_packed, B_packed, beta, C, ldc);
             return;
         }
-
+
         const size_t mr_vec = mr / VECTOR_SIZE;
         const size_t nr_vec = nr / VECTOR_SIZE;
-
-        // Use scalar accumulators for simplicity and correctness
+
         Sleef_quad c_acc[mr][nr];
-
-        // Initialize accumulators
+
         for (size_t i = 0; i < mr; ++i)
         {
             for (size_t j = 0; j < nr; ++j)
             {
                 c_acc[i][j] = SLEEF_QUAD_C(0.0);
             }
         }
-
-        // Main computation with vectorized inner loop when possible
+
         for (size_t i = 0; i < mr; ++i)
         {
             for (size_t j_vec = 0; j_vec < nr_vec; ++j_vec)
             {
                 size_t j_start = j_vec * VECTOR_SIZE;
                 QuadVector sum_vec(SLEEF_QUAD_C(0.0));
-
+
                 for (size_t k = 0; k < kc; ++k)
                 {
                     Sleef_quad a_val = A_packed[i * kc + k];
-                    QuadVector a_vec(a_val); // Broadcast a_val to vector
-
-                    // Load VECTOR_SIZE consecutive B values from row k
+                    QuadVector a_vec(a_val);
+
                     QuadVector b_vec = QuadVector::load(&B_packed[k * nr + j_start]);
-
-                    // Accumulate: sum_vec += a_vec * b_vec
+
                     sum_vec = a_vec.fma(b_vec, sum_vec);
                 }
-
-                // Store back to scalar accumulators
+
                 for (size_t lane = 0; lane < VECTOR_SIZE; ++lane)
                 {
                     c_acc[i][j_start + lane] = sum_vec.get(lane);
                 }
             }
-
-            // Handle remaining columns with scalar code
+
             for (size_t j = nr_vec * VECTOR_SIZE; j < nr; ++j)
             {
                 for (size_t k = 0; k < kc; ++k)
@@ -108,8 +94,7 @@ namespace QuadBLAS
                 }
             }
         }
-
-        // Apply alpha and beta scaling and store to C
+
         for (size_t i = 0; i < mr; ++i)
         {
             for (size_t j = 0; j < nr; ++j)
@@ -120,13 +105,12 @@ namespace QuadBLAS
         }
     }
 
-    // Choose the best micro-kernel based on size
     inline void gemm_micro_kernel(size_t mr, size_t nr, size_t kc,
                                   Sleef_quad alpha,
                                   Sleef_quad *A_packed, Sleef_quad *B_packed,
                                   Sleef_quad beta, Sleef_quad *C, size_t ldc)
     {
-        // Use vectorized version for larger blocks, scalar for smaller/irregular sizes
+
         if (mr >= VECTOR_SIZE && nr >= VECTOR_SIZE && (mr * nr >= 8))
         {
             gemm_micro_kernel_vectorized(mr, nr, kc, alpha, A_packed, B_packed, beta, C, ldc);
@@ -137,37 +121,10 @@ namespace QuadBLAS
         }
     }
 
-    // CORRECTED: Macro-kernel for medium-sized blocks
     inline void gemm_macro_kernel(size_t mc, size_t nc, size_t kc,
                                   Sleef_quad alpha,
                                   Sleef_quad *A_packed, Sleef_quad *B_packed,
                                   Sleef_quad beta, Sleef_quad *C, size_t ldc)
-    {
-        constexpr size_t MR = 4; // Micro-panel height
-        constexpr size_t NR = 4; // Micro-panel width
-
-        for (size_t i = 0; i < mc; i += MR)
-        {
-            size_t mr = std::min(MR, mc - i);
-
-            for (size_t j = 0; j < nc; j += NR)
-            {
-                size_t nr = std::min(NR, nc - j);
-
-                // FIXED: Correct pointers to packed matrices and C submatrix
-                gemm_micro_kernel(mr, nr, kc, alpha,
-                                  &A_packed[i * kc], // Start of rows i to i+mr-1 in A_packed
-                                  &B_packed[j],      // Start of columns j to j+nr-1 in B_packed (but this is still wrong!)
-                                  beta, &C[i * ldc + j], ldc);
-            }
-        }
-    }
-
-    // CORRECTED: Macro-kernel with proper B pointer calculation
-    inline void gemm_macro_kernel_fixed(size_t mc, size_t nc, size_t kc,
-                                        Sleef_quad alpha,
-                                        Sleef_quad *A_packed, Sleef_quad *B_packed,
-                                        Sleef_quad beta, Sleef_quad *C, size_t ldc)
     {
         constexpr size_t MR = 4;
         constexpr size_t NR = 4;
@@ -180,38 +137,35 @@ namespace QuadBLAS
         {
             size_t nr = std::min(NR, nc - j);
 
-            // Create temporary B submatrix for this micro-kernel
-            // We need B[:,j:j+nr] but B_packed is row-major, so we need to extract columns
             Sleef_quad *B_sub = aligned_alloc<Sleef_quad>(kc * nr);
             if (B_sub)
             {
-                // Copy the required columns from B_packed
+
                 for (size_t k = 0; k < kc; ++k)
                 {
                     for (size_t jj = 0; jj < nr; ++jj)
                     {
                         B_sub[k * nr + jj] = B_packed[k * nc + (j + jj)];
                     }
                 }
-
+
                 gemm_micro_kernel(mr, nr, kc, alpha,
                                   &A_packed[i * kc], B_sub,
                                   beta, &C[i * ldc + j], ldc);
-
+
                 aligned_free(B_sub);
             }
             else
             {
-                // Fallback to scalar if allocation fails
+
                 gemm_micro_kernel_scalar(mr, nr, kc, alpha,
-                                         &A_packed[i * kc], &B_packed[j], // Note: this is still incorrect but safer
-                                         beta, &C[i * ldc + j], ldc);
+                                         &A_packed[i * kc], &B_packed[j],
+                                         beta, &C[i * ldc + j], ldc);
             }
         }
     }
 
-    // Simple GEMM implementation for small matrices (this was already correct)
     inline void gemm_simple(Layout layout, size_t m, size_t n, size_t k,
                             Sleef_quad alpha,
                             Sleef_quad *A, size_t lda,
@@ -240,7 +194,6 @@ namespace QuadBLAS
         }
     }
 
-    // CORRECTED: Main GEMM function with safer blocked implementation
     inline void gemm(Layout layout, size_t m, size_t n, size_t k,
                      Sleef_quad alpha,
                      Sleef_quad *A, size_t lda,
@@ -250,8 +203,7 @@ namespace QuadBLAS
         if (m == 0 || n == 0 || k == 0)
             return;
 
-        // Use simple implementation for small matrices OR when the blocked version might have issues
-        constexpr size_t SMALL_MATRIX_THRESHOLD = 64; // Increased threshold for safety
+        constexpr size_t SMALL_MATRIX_THRESHOLD = 64;
         if (m <= SMALL_MATRIX_THRESHOLD && n <= SMALL_MATRIX_THRESHOLD && k <= SMALL_MATRIX_THRESHOLD)
         {
             gemm_simple(layout, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
@@ -260,7 +212,6 @@ namespace QuadBLAS
 
         BlockingParams params(m, n, k);
 
-        // Allocate temporary packed matrices
         Sleef_quad *A_packed = aligned_alloc<Sleef_quad>(params.mc * params.kc);
         Sleef_quad *B_packed = aligned_alloc<Sleef_quad>(params.kc * params.nc);
 
@@ -272,7 +223,6 @@ namespace QuadBLAS
             return;
         }
 
-        // Blocked GEMM implementation
         for (size_t kk = 0; kk < k; kk += params.kc)
         {
             size_t kc = std::min(params.kc, k - kk);
@@ -281,7 +231,6 @@ namespace QuadBLAS
             {
                 size_t mc = std::min(params.mc, m - mm);
 
-                // Pack A panel (this was already correct)
                 for (size_t i = 0; i < mc; ++i)
                 {
                     for (size_t j = 0; j < kc; ++j)
@@ -295,7 +244,6 @@ namespace QuadBLAS
                 {
                     size_t nc = std::min(params.nc, n - nn);
 
-                    // Pack B panel (this was already correct)
                     for (size_t i = 0; i < kc; ++i)
                     {
                         for (size_t j = 0; j < nc; ++j)
@@ -305,14 +253,12 @@ namespace QuadBLAS
                         }
                     }
 
-                    // CORRECTED: Compute C block with proper matrix addressing
                     Sleef_quad *C_block = &C[(layout == Layout::RowMajor) ? mm * ldc + nn : nn * ldc + mm];
 
-                    // Use the corrected macro-kernel
-                    gemm_macro_kernel_fixed(mc, nc, kc, alpha,
-                                            A_packed, B_packed,
-                                            (kk == 0) ? beta : SLEEF_QUAD_C(1.0),
-                                            C_block, ldc);
+                    gemm_macro_kernel(mc, nc, kc, alpha,
+                                      A_packed, B_packed,
+                                      (kk == 0) ? beta : SLEEF_QUAD_C(1.0),
+                                      C_block, ldc);
                 }
             }
         }
@@ -321,6 +267,6 @@ namespace QuadBLAS
         aligned_free(B_packed);
     }
 
-} // namespace QuadBLAS
+}
 
-#endif // QUADBLAS_ALGORITHMS_LEVEL3_HPP
+#endif
