Skip to content

Commit 25d4ffa

Browse files
committed
feat: Metal_Dense uses Accelerate dense LU baseline (sgetrf/sgetrs)
Replace BaSpaCho supernodal Metal_Dense (impractical with 25K per-lump dispatches) with standalone Accelerate BLAS dense LU. Uses sgetrf for factorization and sgetrs for solve, with mixed-precision iterative refinement (float factor, double SpMV residual). C6288 (n=25380): factor ~10.3s, solve ~0.31s, residual ~1e-11. Demonstrates 880x benefit of sparse exploitation (Metal_Sparse: ~12ms). Size guard skips matrices >50K (dense n² would exceed memory). Also adds sgetrs/dgetrs declarations + LAPACKE wrappers to BlasDefs.h. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
1 parent 4330865 commit 25d4ffa

File tree

2 files changed

+164
-10
lines changed

2 files changed

+164
-10
lines changed

baspacho/baspacho/BlasDefs.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ void dgetrf_(const BLAS_INT* m, const BLAS_INT* n, double* A, const BLAS_INT* ld
5959
BLAS_INT* info);
6060
void sgetrf_(const BLAS_INT* m, const BLAS_INT* n, float* A, const BLAS_INT* lda, BLAS_INT* ipiv,
6161
BLAS_INT* info);
62+
63+
// LU solve (getrs) — solve A*X = B given LU factorization from getrf
64+
void dgetrs_(const char* trans, const BLAS_INT* n, const BLAS_INT* nrhs, const double* A,
65+
const BLAS_INT* lda, const BLAS_INT* ipiv, double* B, const BLAS_INT* ldb,
66+
BLAS_INT* info);
67+
void sgetrs_(const char* trans, const BLAS_INT* n, const BLAS_INT* nrhs, const float* A,
68+
const BLAS_INT* lda, const BLAS_INT* ipiv, float* B, const BLAS_INT* ldb,
69+
BLAS_INT* info);
6270
}
6371

6472
#define CBLAS_LAYOUT int
@@ -169,4 +177,19 @@ inline BLAS_INT LAPACKE_sgetrf(int /* matrix_layout */, BLAS_INT m, BLAS_INT n,
169177
return info;
170178
}
171179

180+
// LU solve wrappers
181+
inline BLAS_INT LAPACKE_dgetrs(char trans, BLAS_INT n, BLAS_INT nrhs, const double* a,
182+
BLAS_INT lda, const BLAS_INT* ipiv, double* b, BLAS_INT ldb) {
183+
BLAS_INT info;
184+
dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, &info);
185+
return info;
186+
}
187+
188+
inline BLAS_INT LAPACKE_sgetrs(char trans, BLAS_INT n, BLAS_INT nrhs, const float* a,
189+
BLAS_INT lda, const BLAS_INT* ipiv, float* b, BLAS_INT ldb) {
190+
BLAS_INT info;
191+
sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, &info);
192+
return info;
193+
}
194+
172195
} // end namespace BaSpaCho

baspacho/benchmarking/LUBench.cpp

Lines changed: 141 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
#include "baspacho/baspacho/MetalDefs.h"
4141
#endif
4242

43+
#ifdef BASPACHO_USE_BLAS
44+
#include "baspacho/baspacho/BlasDefs.h"
45+
#endif
46+
4347
using namespace BaSpaCho;
4448
using namespace BaSpaCho::testing_utils;
4549
using namespace std;
@@ -248,15 +252,11 @@ static vector<LUTimingResult> benchmarkLUCpu(
248252
// ============================================================================
249253

250254
#ifdef BASPACHO_USE_METAL
251-
// Metal LU benchmark: GPU sparse elimination + CPU BLAS dense + CPU SpMV refinement.
255+
// Metal LU benchmark (Metal_Sparse): GPU sparse elimination + CPU BLAS dense + CPU SpMV refinement.
252256
// Mirrors the spineax BaspachoGpuInstantiate/Execute FFI code path.
253257
// Uses persistent contexts, device-resident pivots, recording pass, and external encoder.
254-
//
255-
// useSparseElim=true (Metal_Sparse): GPU sparse elim for scalar lumps, CPU dense for rest.
256-
// useSparseElim=false (Metal_Dense): all-dense LU on GPU (no sparse elimination).
257258
static vector<LUTimingResult> benchmarkLUMetalFFI(
258-
const vector<pair<CsrMatrix, Eigen::VectorXd>>& matrices, int maxRefine,
259-
bool verbose, bool useSparseElim = true) {
259+
const vector<pair<CsrMatrix, Eigen::VectorXd>>& matrices, int maxRefine, bool verbose) {
260260
if (matrices.empty()) return {};
261261

262262
bool capturing = MetalContext::instance().beginCaptureIfRequested("/tmp/baspacho_ffi.gputrace");
@@ -313,7 +313,7 @@ static vector<LUTimingResult> benchmarkLUMetalFFI(
313313
settings.matrixType = MTYPE_GENERAL;
314314
settings.numThreads = 1;
315315
settings.staticPivotThreshold = pivotThreshold;
316-
settings.findSparseEliminationRanges = useSparseElim;
316+
settings.findSparseEliminationRanges = true;
317317

318318
vector<int64_t> paramSizes(n, 1);
319319
vector<int64_t> blockSizes(n, 1);
@@ -534,6 +534,132 @@ static vector<LUTimingResult> benchmarkLUMetalFFI(
534534
}
535535
#endif // BASPACHO_USE_METAL
536536

537+
// ============================================================================
538+
// Dense BLAS LU baseline (Accelerate sgetrf/sgetrs, no sparsity exploitation)
539+
// ============================================================================
540+
541+
#ifdef BASPACHO_USE_BLAS
542+
static vector<LUTimingResult> benchmarkLUDenseBLAS(
543+
const vector<pair<CsrMatrix, Eigen::VectorXd>>& matrices, int maxRefine, bool verbose) {
544+
if (matrices.empty()) return {};
545+
546+
const CsrMatrix& A0 = matrices[0].first;
547+
int64_t n = A0.nRows;
548+
549+
// Size guard: dense n×n float = n² × 4 bytes. n=50000 → ~10GB.
550+
if (n > 50000) {
551+
if (verbose)
552+
cout << " [DenseBLAS] Skipping: n=" << n << " too large for dense (" << (n * n * 4.0 / 1e9)
553+
<< " GB)" << endl;
554+
return {};
555+
}
556+
557+
// Preprocessing: BTF max transversal (once per pattern)
558+
auto preproc = computeMaxTransversal(n, A0.rowPtr.data(), A0.colInd.data());
559+
560+
vector<LUTimingResult> results;
561+
vector<int64_t> pRowPtr, pColInd;
562+
vector<double> pValues;
563+
564+
// Dense matrix (column-major for LAPACK) + pivot array
565+
vector<float> dense(n * n, 0.0f);
566+
vector<BLAS_INT> ipiv(n);
567+
568+
for (size_t mi = 0; mi < matrices.size(); mi++) {
569+
const CsrMatrix& A = matrices[mi].first;
570+
const Eigen::VectorXd& b = matrices[mi].second;
571+
LUTimingResult res;
572+
573+
// Equilibration (same as Metal_Sparse)
574+
applyRowPermToCsr<double>(n, A.rowPtr.data(), A.colInd.data(), A.values.data(),
575+
preproc.rowPerm.data(), pRowPtr, pColInd, pValues);
576+
vector<double> rowScale, colScale;
577+
computeEquilibration(n, pRowPtr.data(), pColInd.data(), pValues.data(), rowScale, colScale);
578+
579+
applyRowPermAndScaleToCsr<double>(n, A.rowPtr.data(), A.colInd.data(), A.values.data(),
580+
preproc.rowPerm.data(), rowScale.data(), colScale.data(),
581+
pRowPtr, pColInd, pValues);
582+
583+
// Scatter equilibrated sparse CSR → dense column-major: dense[col * n + row] = value
584+
fill(dense.begin(), dense.end(), 0.0f);
585+
for (int64_t i = 0; i < n; i++) {
586+
for (int64_t k = pRowPtr[i]; k < pRowPtr[i + 1]; k++) {
587+
dense[pColInd[k] * n + i] = float(pValues[k]);
588+
}
589+
}
590+
591+
// Factor: sgetrf
592+
auto tFactor = Clock::now();
593+
BLAS_INT N = static_cast<BLAS_INT>(n);
594+
BLAS_INT info = LAPACKE_sgetrf(LAPACK_COL_MAJOR, N, N, dense.data(), N, ipiv.data());
595+
res.factorTime = tdelta(Clock::now() - tFactor).count();
596+
597+
if (info != 0 && verbose) {
598+
cout << " [DenseBLAS] Matrix #" << mi << ": sgetrf info=" << info << endl;
599+
}
600+
601+
// Solve + iterative refinement
602+
auto tSolve = Clock::now();
603+
604+
// Initial solve: permute RHS, sgetrs, unscale
605+
vector<float> rhsF(n);
606+
for (int64_t j = 0; j < n; j++)
607+
rhsF[j] = float(rowScale[j] * b(preproc.rowPerm[j]));
608+
609+
char trans = 'N';
610+
BLAS_INT nrhs = 1;
611+
LAPACKE_sgetrs(trans, N, nrhs, dense.data(), N, ipiv.data(), rhsF.data(), N);
612+
613+
// Accumulate solution in double precision
614+
Eigen::VectorXd x(n);
615+
for (int64_t j = 0; j < n; j++)
616+
x(j) = colScale[j] * double(rhsF[j]);
617+
618+
double residual = computeResidualDouble(A, x, b);
619+
res.refineSteps = 0;
620+
621+
// Iterative refinement: CPU SpMV (double) → sgetrs (float) → accumulate (double)
622+
for (int iter = 0; iter < maxRefine && residual > 1e-10; iter++) {
623+
// SpMV residual in double precision
624+
Eigen::VectorXd r = Eigen::VectorXd::Zero(n);
625+
for (int64_t i = 0; i < n; i++) {
626+
for (int64_t k = A.rowPtr[i]; k < A.rowPtr[i + 1]; k++)
627+
r(i) += A.values[k] * x(A.colInd[k]);
628+
}
629+
r = b - r;
630+
631+
// Permute + scale residual → float RHS
632+
for (int64_t j = 0; j < n; j++)
633+
rhsF[j] = float(rowScale[j] * r(preproc.rowPerm[j]));
634+
635+
// Solve correction
636+
LAPACKE_sgetrs(trans, N, nrhs, dense.data(), N, ipiv.data(), rhsF.data(), N);
637+
638+
// Accumulate correction in double
639+
for (int64_t j = 0; j < n; j++)
640+
x(j) += colScale[j] * double(rhsF[j]);
641+
642+
residual = computeResidualDouble(A, x, b);
643+
res.refineSteps++;
644+
}
645+
646+
res.solveTime = tdelta(Clock::now() - tSolve).count();
647+
res.residual = residual;
648+
res.perturbCount = 0;
649+
650+
if (verbose) {
651+
cout << " [DenseBLAS] Matrix #" << mi << ": factor=" << fixed << setprecision(4)
652+
<< res.factorTime << "s, solve=" << res.solveTime << "s, residual=" << scientific
653+
<< setprecision(2) << res.residual << ", refine=" << res.refineSteps << endl;
654+
}
655+
656+
results.push_back(res);
657+
}
658+
659+
return results;
660+
}
661+
#endif // BASPACHO_USE_BLAS
662+
537663
// ============================================================================
538664
// CUDA (double) benchmark
539665
// ============================================================================
@@ -936,7 +1062,7 @@ void help() {
9361062
<< " BaSpaCho_LU_CPU\n"
9371063
#ifdef BASPACHO_USE_METAL
9381064
<< " Metal_Sparse (GPU sparse elim + CPU BLAS dense + CPU SpMV refinement)\n"
939-
<< " Metal_Dense (all-dense GPU LU, no sparse elimination)\n"
1065+
<< " Metal_Dense (Accelerate dense LU baseline, no sparsity exploitation)\n"
9401066
#endif
9411067
#ifdef BASPACHO_USE_CUBLAS
9421068
<< " BaSpaCho_LU_CUDA\n"
@@ -1144,15 +1270,20 @@ int main(int argc, char* argv[]) {
11441270
#ifdef BASPACHO_USE_METAL
11451271
if (regex_search(string("Metal_Sparse"), selectSolvers)) {
11461272
if (!jsonOutput) cout << "\nRunning Metal_Sparse..." << endl;
1147-
auto timings = benchmarkLUMetalFFI(matrices, maxRefineIters, verbose, true);
1273+
auto timings = benchmarkLUMetalFFI(matrices, maxRefineIters, verbose);
11481274
if (isWarmup && timings.size() > 1) timings.erase(timings.begin());
11491275
resultToRecords(problemName, "Metal_Sparse", timings, allRecords);
11501276
if (!jsonOutput) printResults("Metal_Sparse", timings);
11511277
}
11521278

11531279
if (regex_search(string("Metal_Dense"), selectSolvers)) {
11541280
if (!jsonOutput) cout << "\nRunning Metal_Dense..." << endl;
1155-
auto timings = benchmarkLUMetalFFI(matrices, maxRefineIters, verbose, false);
1281+
#ifdef BASPACHO_USE_BLAS
1282+
auto timings = benchmarkLUDenseBLAS(matrices, maxRefineIters, verbose);
1283+
#else
1284+
vector<LUTimingResult> timings;
1285+
if (!jsonOutput) cout << " [Metal_Dense] Skipping: BLAS not available" << endl;
1286+
#endif
11561287
if (isWarmup && timings.size() > 1) timings.erase(timings.begin());
11571288
resultToRecords(problemName, "Metal_Dense", timings, allRecords);
11581289
if (!jsonOutput) printResults("Metal_Dense", timings);

0 commit comments

Comments
 (0)