namespace QuadBLAS
{

-    // CORRECTED: Scalar micro-kernel for reliability
    inline void gemm_micro_kernel_scalar(size_t mr, size_t nr, size_t kc,
-                                         Sleef_quad alpha,
-                                         Sleef_quad *A_packed, Sleef_quad *B_packed,
-                                         Sleef_quad beta, Sleef_quad *C, size_t ldc)
+                                         Sleef_quad alpha,
+                                         Sleef_quad *A_packed, Sleef_quad *B_packed,
+                                         Sleef_quad beta, Sleef_quad *C, size_t ldc)
    {
-        // A_packed: row-major, mr x kc (A_packed[row * kc + col])
-        // B_packed: row-major, kc x nr (B_packed[row * nr + col])
-        // C: original matrix with leading dimension ldc
-
+
        for (size_t i = 0; i < mr; ++i)
        {
            for (size_t j = 0; j < nr; ++j)
            {
                Sleef_quad sum = SLEEF_QUAD_C(0.0);
-
-                // Compute dot product: A_packed[i,:] • B_packed[:,j]
+
                for (size_t k = 0; k < kc; ++k)
                {
-                    Sleef_quad a_val = A_packed[i * kc + k]; // A[i][k]
-                    Sleef_quad b_val = B_packed[k * nr + j]; // B[k][j]
-                    sum = Sleef_fmaq1_u05(a_val, b_val, sum); // sum += a * b
+                    Sleef_quad a_val = A_packed[i * kc + k];
+                    Sleef_quad b_val = B_packed[k * nr + j];
+                    sum = Sleef_fmaq1_u05(a_val, b_val, sum);
                }
-
-                // C[i][j] = alpha * sum + beta * C[i][j]
+
                Sleef_quad c_old = C[i * ldc + j];
                C[i * ldc + j] = Sleef_fmaq1_u05(alpha, sum, Sleef_mulq1_u05(beta, c_old));
            }
        }
    }

-    // CORRECTED: Vectorized micro-kernel (more complex but correct)
    inline void gemm_micro_kernel_vectorized(size_t mr, size_t nr, size_t kc,
-                                             Sleef_quad alpha,
-                                             Sleef_quad *A_packed, Sleef_quad *B_packed,
-                                             Sleef_quad beta, Sleef_quad *C, size_t ldc)
+                                             Sleef_quad alpha,
+                                             Sleef_quad *A_packed, Sleef_quad *B_packed,
+                                             Sleef_quad beta, Sleef_quad *C, size_t ldc)
    {
-        // Only use vectorization if dimensions are suitable
+
        if (mr % VECTOR_SIZE != 0 || nr % VECTOR_SIZE != 0 || mr < VECTOR_SIZE || nr < VECTOR_SIZE)
        {
            gemm_micro_kernel_scalar(mr, nr, kc, alpha, A_packed, B_packed, beta, C, ldc);
            return;
        }
-
+
        const size_t mr_vec = mr / VECTOR_SIZE;
        const size_t nr_vec = nr / VECTOR_SIZE;
-
-        // Use scalar accumulators for simplicity and correctness
+
        Sleef_quad c_acc[mr][nr];
-
-        // Initialize accumulators
+
        for (size_t i = 0; i < mr; ++i)
        {
            for (size_t j = 0; j < nr; ++j)
            {
                c_acc[i][j] = SLEEF_QUAD_C(0.0);
            }
        }
-
-        // Main computation with vectorized inner loop when possible
+
        for (size_t i = 0; i < mr; ++i)
        {
            for (size_t j_vec = 0; j_vec < nr_vec; ++j_vec)
            {
                size_t j_start = j_vec * VECTOR_SIZE;
                QuadVector sum_vec(SLEEF_QUAD_C(0.0));
-
+
                for (size_t k = 0; k < kc; ++k)
                {
                    Sleef_quad a_val = A_packed[i * kc + k];
-                    QuadVector a_vec(a_val); // Broadcast a_val to vector
-
-                    // Load VECTOR_SIZE consecutive B values from row k
+                    QuadVector a_vec(a_val);
+
                    QuadVector b_vec = QuadVector::load(&B_packed[k * nr + j_start]);
-
-                    // Accumulate: sum_vec += a_vec * b_vec
+
                    sum_vec = a_vec.fma(b_vec, sum_vec);
                }
-
-                // Store back to scalar accumulators
+
                for (size_t lane = 0; lane < VECTOR_SIZE; ++lane)
                {
                    c_acc[i][j_start + lane] = sum_vec.get(lane);
                }
            }
-
-            // Handle remaining columns with scalar code
+
            for (size_t j = nr_vec * VECTOR_SIZE; j < nr; ++j)
            {
                for (size_t k = 0; k < kc; ++k)
@@ -108,8 +94,7 @@ namespace QuadBLAS
                }
            }
        }
-
-        // Apply alpha and beta scaling and store to C
+
        for (size_t i = 0; i < mr; ++i)
        {
            for (size_t j = 0; j < nr; ++j)
@@ -120,13 +105,12 @@ namespace QuadBLAS
        }
    }

-    // Choose the best micro-kernel based on size
    inline void gemm_micro_kernel(size_t mr, size_t nr, size_t kc,
                                  Sleef_quad alpha,
                                  Sleef_quad *A_packed, Sleef_quad *B_packed,
                                  Sleef_quad beta, Sleef_quad *C, size_t ldc)
    {
-        // Use vectorized version for larger blocks, scalar for smaller/irregular sizes
+
        if (mr >= VECTOR_SIZE && nr >= VECTOR_SIZE && (mr * nr >= 8))
        {
            gemm_micro_kernel_vectorized(mr, nr, kc, alpha, A_packed, B_packed, beta, C, ldc);
@@ -137,37 +121,10 @@ namespace QuadBLAS
        }
    }

-    // CORRECTED: Macro-kernel for medium-sized blocks
    inline void gemm_macro_kernel(size_t mc, size_t nc, size_t kc,
                                  Sleef_quad alpha,
                                  Sleef_quad *A_packed, Sleef_quad *B_packed,
                                  Sleef_quad beta, Sleef_quad *C, size_t ldc)
-    {
-        constexpr size_t MR = 4; // Micro-panel height
-        constexpr size_t NR = 4; // Micro-panel width
-
-        for (size_t i = 0; i < mc; i += MR)
-        {
-            size_t mr = std::min(MR, mc - i);
-
-            for (size_t j = 0; j < nc; j += NR)
-            {
-                size_t nr = std::min(NR, nc - j);
-
-                // FIXED: Correct pointers to packed matrices and C submatrix
-                gemm_micro_kernel(mr, nr, kc, alpha,
-                                  &A_packed[i * kc], // Start of rows i to i+mr-1 in A_packed
-                                  &B_packed[j],      // Start of columns j to j+nr-1 in B_packed (but this is still wrong!)
-                                  beta, &C[i * ldc + j], ldc);
-            }
-        }
-    }
-
-    // CORRECTED: Macro-kernel with proper B pointer calculation
-    inline void gemm_macro_kernel_fixed(size_t mc, size_t nc, size_t kc,
-                                        Sleef_quad alpha,
-                                        Sleef_quad *A_packed, Sleef_quad *B_packed,
-                                        Sleef_quad beta, Sleef_quad *C, size_t ldc)
    {
        constexpr size_t MR = 4;
        constexpr size_t NR = 4;
@@ -180,38 +137,35 @@ namespace QuadBLAS
            {
                size_t nr = std::min(NR, nc - j);

-                // Create temporary B submatrix for this micro-kernel
-                // We need B[:,j:j+nr] but B_packed is row-major, so we need to extract columns
                Sleef_quad *B_sub = aligned_alloc<Sleef_quad>(kc * nr);
                if (B_sub)
                {
-                    // Copy the required columns from B_packed
+
                    for (size_t k = 0; k < kc; ++k)
                    {
                        for (size_t jj = 0; jj < nr; ++jj)
                        {
                            B_sub[k * nr + jj] = B_packed[k * nc + (j + jj)];
                        }
                    }
-
+
                    gemm_micro_kernel(mr, nr, kc, alpha,
                                      &A_packed[i * kc], B_sub,
                                      beta, &C[i * ldc + j], ldc);
-
+
                    aligned_free(B_sub);
                }
                else
                {
-                    // Fallback to scalar if allocation fails
+
                    gemm_micro_kernel_scalar(mr, nr, kc, alpha,
-                                             &A_packed[i * kc], &B_packed[j], // Note: this is still incorrect but safer
-                                             beta, &C[i * ldc + j], ldc);
+                                             &A_packed[i * kc], &B_packed[j],
+                                             beta, &C[i * ldc + j], ldc);
                }
            }
        }
    }

-    // Simple GEMM implementation for small matrices (this was already correct)
    inline void gemm_simple(Layout layout, size_t m, size_t n, size_t k,
                            Sleef_quad alpha,
                            Sleef_quad *A, size_t lda,
@@ -240,7 +194,6 @@ namespace QuadBLAS
        }
    }

-    // CORRECTED: Main GEMM function with safer blocked implementation
    inline void gemm(Layout layout, size_t m, size_t n, size_t k,
                     Sleef_quad alpha,
                     Sleef_quad *A, size_t lda,
@@ -250,8 +203,7 @@ namespace QuadBLAS
        if (m == 0 || n == 0 || k == 0)
            return;

-        // Use simple implementation for small matrices OR when the blocked version might have issues
-        constexpr size_t SMALL_MATRIX_THRESHOLD = 64; // Increased threshold for safety
+        constexpr size_t SMALL_MATRIX_THRESHOLD = 64;
        if (m <= SMALL_MATRIX_THRESHOLD && n <= SMALL_MATRIX_THRESHOLD && k <= SMALL_MATRIX_THRESHOLD)
        {
            gemm_simple(layout, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
@@ -260,7 +212,6 @@ namespace QuadBLAS

        BlockingParams params(m, n, k);

-        // Allocate temporary packed matrices
        Sleef_quad *A_packed = aligned_alloc<Sleef_quad>(params.mc * params.kc);
        Sleef_quad *B_packed = aligned_alloc<Sleef_quad>(params.kc * params.nc);

@@ -272,7 +223,6 @@ namespace QuadBLAS
            return;
        }

-        // Blocked GEMM implementation
        for (size_t kk = 0; kk < k; kk += params.kc)
        {
            size_t kc = std::min(params.kc, k - kk);
@@ -281,7 +231,6 @@ namespace QuadBLAS
            {
                size_t mc = std::min(params.mc, m - mm);

-                // Pack A panel (this was already correct)
                for (size_t i = 0; i < mc; ++i)
                {
                    for (size_t j = 0; j < kc; ++j)
@@ -295,7 +244,6 @@ namespace QuadBLAS
                {
                    size_t nc = std::min(params.nc, n - nn);

-                    // Pack B panel (this was already correct)
                    for (size_t i = 0; i < kc; ++i)
                    {
                        for (size_t j = 0; j < nc; ++j)
@@ -305,14 +253,12 @@ namespace QuadBLAS
                        }
                    }

-                    // CORRECTED: Compute C block with proper matrix addressing
                    Sleef_quad *C_block = &C[(layout == Layout::RowMajor) ? mm * ldc + nn : nn * ldc + mm];

-                    // Use the corrected macro-kernel
-                    gemm_macro_kernel_fixed(mc, nc, kc, alpha,
-                                            A_packed, B_packed,
-                                            (kk == 0) ? beta : SLEEF_QUAD_C(1.0),
-                                            C_block, ldc);
+                    gemm_macro_kernel(mc, nc, kc, alpha,
+                                      A_packed, B_packed,
+                                      (kk == 0) ? beta : SLEEF_QUAD_C(1.0),
+                                      C_block, ldc);
                }
            }
        }
@@ -321,6 +267,6 @@ namespace QuadBLAS
        aligned_free(B_packed);
    }

-} // namespace QuadBLAS
+}

-#endif // QUADBLAS_ALGORITHMS_LEVEL3_HPP
+#endif
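For reference, a minimal caller of the blocked routine above might look like the sketch below. It is an assumption-laden sketch, not part of the commit: the tail of the gemm() parameter list (B, ldb, beta, C, ldc) is inferred from the gemm_simple() call site visible in the diff, and the include path "quadblas/level3.hpp" is hypothetical. Only the QuadBLAS namespace, Layout::RowMajor, Sleef_quad, and SLEEF_QUAD_C come from the code itself.

    // Usage sketch: C = 1.0 * A * B + 0.0 * C in quad precision, row-major.
    // Dimensions exceed the 64x64x64 threshold, so the blocked path is exercised.
    #include <cstddef>
    #include <vector>
    #include <sleefquad.h>
    #include "quadblas/level3.hpp" // hypothetical include path for the header in this diff

    int main()
    {
        const size_t m = 100, n = 100, k = 100;
        std::vector<Sleef_quad> A(m * k, SLEEF_QUAD_C(1.0));
        std::vector<Sleef_quad> B(k * n, SLEEF_QUAD_C(2.0));
        std::vector<Sleef_quad> C(m * n, SLEEF_QUAD_C(0.0));

        // Tight leading dimensions: lda = k, ldb = n, ldc = n for row-major storage.
        QuadBLAS::gemm(QuadBLAS::Layout::RowMajor, m, n, k,
                       SLEEF_QUAD_C(1.0), A.data(), k,
                       B.data(), n,
                       SLEEF_QUAD_C(0.0), C.data(), n);
        return 0;
    }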