Skip to content

Commit c26dc30

Browse files
committed
feat: GPU-resident prepareAssemble + recording mode for Metal backend
- Add prepareAssemble_kernel_float Metal compute shader that reads device-resident skeleton arrays (chainColPtr, chainRowSpan, chainData), replacing the CPU loop + memcpy in MetalNumericCtx::prepareAssemble(). Eliminates a CPU→GPU sync point per lump. - Implement beginRecording()/endRecording() in MetalNumericCtx<float>, mirroring the CUDA implementation. Records LUGemmWorkItem dispatch schedule + flush boundaries at init time, replays from pre-computed MetalMirror buffer during subsequent factorizations. Eliminates per-lump CPU memcpy of work items in flushPendingGemms(). - Add recordingMode_ guards to all GPU-executing methods (getrf, trsm, applyRowPerm, prepareAssemble, assemble, doElimination*, potrf, trsm, saveSyrkGemm, maxAbsDiag, perturbSmallDiagonals, beginDenseOps, flush, flushDevicePivots, deferredPerturbCount). - Update reset() to reset precomputedFlushIdx_ while preserving pre-computed buffers across factorizations. All 5 MetalLUTest tests pass. lu_bench correctness verified with ring Jacobian (47x47, median 0.75ms factor, 0.59ms solve). Co-developed-by: Claude Code (claude-opus-4-6)
1 parent bfd5962 commit c26dc30

File tree

2 files changed

+170
-9
lines changed

2 files changed

+170
-9
lines changed

baspacho/baspacho/MatOpsMetal.mm

Lines changed: 151 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ void collectDeferredElimPerturb() {
687687

688688
// Signal start of dense LU operations — wait for deferred sparse elim GPU work.
689689
void beginDenseOps(float* data, int64_t totalDataSize) override {
690+
if (recordingMode_) return; // no-op during recording
690691
(void)data;
691692
(void)totalDataSize;
692693
waitForGpu();
@@ -706,6 +707,46 @@ void commitAndWait() {
706707
// devGemmWorkBuf_, avoiding the need to wait for GPU to finish reading
707708
// previous data. Buffer reset happens at flush() time between factorizations.
708709
void flushPendingGemms() {
710+
if (recordingMode_) {
711+
// Recording mode: record flush point, skip GPU work
712+
if (recordingBatchCount_ > 0) {
713+
size_t startIdx = recordedItems_.size() - recordingBatchCount_;
714+
recordedFlushPoints_.push_back({startIdx, recordingBatchCount_});
715+
recordingBatchCount_ = 0;
716+
}
717+
pendingGemms_.clear();
718+
return;
719+
}
720+
721+
if (usePrecomputed_) {
722+
// Pre-computed mode: dispatch from device-resident items
723+
if (precomputedFlushIdx_ >= recordedFlushPoints_.size()) return;
724+
auto [startIdx, count] = recordedFlushPoints_[precomputedFlushIdx_];
725+
precomputedFlushIdx_++;
726+
if (count == 0) return;
727+
728+
// Byte offset into pre-computed buffer (MetalMirror backing is already aligned)
729+
size_t byteOffset = startIdx * sizeof(LUGemmWorkItem);
730+
731+
int64_t countI64 = (int64_t)count;
732+
733+
id<MTLComputePipelineState> pipeline = getProfiledPipeline(
734+
"lu_batchedSaveGemm_kernel_float");
735+
736+
encodeKernel(
737+
pipeline,
738+
^(id<MTLComputeCommandEncoder> encoder) {
739+
[encoder setBuffer:cachedDataBuffer_ offset:0 atIndex:0];
740+
[encoder setBuffer:(__bridge id<MTLBuffer>)devPrecomputedItems_.buffer()
741+
offset:byteOffset
742+
atIndex:1];
743+
[encoder setBytes:&countI64 length:sizeof(int64_t) atIndex:2];
744+
},
745+
(NSUInteger)count);
746+
pendingGemms_.clear();
747+
return;
748+
}
749+
709750
if (pendingGemms_.empty()) return;
710751

711752
int64_t count = (int64_t)pendingGemms_.size();
@@ -756,6 +797,7 @@ void flushPendingGemms() {
756797
}
757798

758799
virtual void pseudoFactorSpans(float* data, int64_t spanBegin, int64_t spanEnd) override {
800+
if (recordingMode_) return;
759801
@autoreleasepool {
760802
// Find the MTLBuffer for data
761803
auto bufferInfo = MetalBufferRegistry::instance().findBuffer(data);
@@ -805,6 +847,7 @@ virtual void pseudoFactorSpans(float* data, int64_t spanBegin, int64_t spanEnd)
805847

806848
virtual void doElimination(const SymElimCtx& elimData, float* data, int64_t lumpsBegin,
807849
int64_t lumpsEnd) override {
850+
if (recordingMode_) return;
808851
@autoreleasepool {
809852
const MetalSymElimCtx* pElim = dynamic_cast<const MetalSymElimCtx*>(&elimData);
810853
BASPACHO_CHECK_NOTNULL(pElim);
@@ -903,6 +946,7 @@ virtual void doElimination(const SymElimCtx& elimData, float* data, int64_t lump
903946
virtual void doEliminationLU(const SymElimCtx& elimData, float* data, int64_t lumpsBegin,
904947
int64_t lumpsEnd, float staticPivotThreshold,
905948
int64_t& perturbCount) override {
949+
if (recordingMode_) return;
906950
@autoreleasepool {
907951
const MetalSymElimCtx* pElim = dynamic_cast<const MetalSymElimCtx*>(&elimData);
908952
BASPACHO_CHECK_NOTNULL(pElim);
@@ -1018,6 +1062,7 @@ void doAllEliminationsLU(const std::vector<SymElimCtxPtr>& elimCtxs,
10181062
const std::vector<int64_t>& ranges, float* data,
10191063
float staticPivotThreshold,
10201064
int64_t& totalPerturbCount) override {
1065+
if (recordingMode_) return;
10211066
@autoreleasepool {
10221067
auto bufferInfo = MetalBufferRegistry::instance().findBuffer(data);
10231068
if (!bufferInfo.first) {
@@ -1160,6 +1205,7 @@ void doAllEliminationsLU(const std::vector<SymElimCtxPtr>& elimCtxs,
11601205
// Phase 2: segmented sum per target in fixed order (deterministic)
11611206
void doAllEliminations(const std::vector<SymElimCtxPtr>& elimCtxs,
11621207
const std::vector<int64_t>& ranges, float* data) override {
1208+
if (recordingMode_) return;
11631209
@autoreleasepool {
11641210
auto bufferInfo = MetalBufferRegistry::instance().findBuffer(data);
11651211
if (!bufferInfo.first) {
@@ -1286,6 +1332,7 @@ void doAllEliminations(const std::vector<SymElimCtxPtr>& elimCtxs,
12861332

12871333
virtual double maxAbsDiag(const float* data, const int64_t* lumpStart, const int64_t* chainColPtr,
12881334
const int64_t* chainData, int64_t startLump, int64_t upToLump) override {
1335+
if (recordingMode_) return 0.0;
12891336
@autoreleasepool {
12901337
int64_t numLumps = upToLump - startLump;
12911338
if (numLumps <= 0) return 0.0;
@@ -1332,6 +1379,7 @@ virtual double maxAbsDiag(const float* data, const int64_t* lumpStart, const int
13321379
}
13331380

13341381
virtual void potrf(int64_t n, float* data, int64_t offA) override {
1382+
if (recordingMode_) return;
13351383
@autoreleasepool {
13361384
if (n <= 0) return;
13371385

@@ -1387,6 +1435,7 @@ virtual void potrf(int64_t n, float* data, int64_t offA) override {
13871435
}
13881436

13891437
virtual void trsm(int64_t n, int64_t k, float* data, int64_t offA, int64_t offB) override {
1438+
if (recordingMode_) return;
13901439
@autoreleasepool {
13911440
if (n <= 0 || k <= 0) return;
13921441

@@ -1449,6 +1498,7 @@ virtual void trsm(int64_t n, int64_t k, float* data, int64_t offA, int64_t offB)
14491498

14501499
virtual void saveSyrkGemm(int64_t m, int64_t n, int64_t k, const float* data,
14511500
int64_t offset) override {
1501+
if (recordingMode_) return;
14521502
@autoreleasepool {
14531503
if (m <= 0 || n <= 0 || k <= 0) return;
14541504

@@ -1528,6 +1578,8 @@ virtual void saveSyrkGemm(int64_t m, int64_t n, int64_t k, const float* data,
15281578
}
15291579

15301580
virtual void prepareAssemble(int64_t targetLump) override {
1581+
if (recordingMode_) return; // no-op during recording
1582+
15311583
// Only flush if assemble() was actually called since last prepareAssemble.
15321584
// For LU factorization (isGeneral()==true), eliminateBoardLU only calls
15331585
// saveGemm — never assemble — so flushing is unnecessary and avoiding it
@@ -1537,22 +1589,36 @@ virtual void prepareAssemble(int64_t targetLump) override {
15371589
assembleWasCalled_ = false;
15381590
}
15391591

1540-
// Prepare chain offset mapping for assembly (same as CUDA version)
1592+
// GPU kernel: reads device-resident skeleton arrays, writes devSpanToChainOffset.
1593+
// All inputs (devChainColPtr, devChainRowSpan, devChainData) are already on device
1594+
// in MetalSymbolicCtx. No CPU loop, no memcpy — eliminates CPU→GPU sync point.
15411595
const CoalescedBlockMatrixSkel& skel = sym.skel;
1596+
int64_t numEntries = skel.chainColPtr[targetLump + 1] - skel.chainColPtr[targetLump];
1597+
if (numEntries <= 0) return;
15421598

1543-
for (int64_t i = skel.chainColPtr[targetLump], iEnd = skel.chainColPtr[targetLump + 1]; i < iEnd;
1544-
i++) {
1545-
spanToChainOffset[skel.chainRowSpan[i]] = skel.chainData[i];
1546-
}
1599+
id<MTLComputePipelineState> pipeline = getProfiledPipeline(
1600+
"prepareAssemble_kernel_float");
15471601

1548-
// Copy to device
1549-
memcpy(devSpanToChainOffset.ptr(), spanToChainOffset.data(),
1550-
spanToChainOffset.size() * sizeof(int64_t));
1602+
encodeKernel(
1603+
pipeline,
1604+
^(id<MTLComputeCommandEncoder> encoder) {
1605+
[encoder setBuffer:(__bridge id<MTLBuffer>)sym.devChainColPtr.buffer()
1606+
offset:0 atIndex:0];
1607+
[encoder setBuffer:(__bridge id<MTLBuffer>)sym.devChainRowSpan.buffer()
1608+
offset:0 atIndex:1];
1609+
[encoder setBuffer:(__bridge id<MTLBuffer>)sym.devChainData.buffer()
1610+
offset:0 atIndex:2];
1611+
[encoder setBuffer:(__bridge id<MTLBuffer>)devSpanToChainOffset.buffer()
1612+
offset:0 atIndex:3];
1613+
[encoder setBytes:&targetLump length:sizeof(int64_t) atIndex:4];
1614+
},
1615+
(NSUInteger)numEntries);
15511616
}
15521617

15531618
virtual void assemble(float* data, int64_t rectRowBegin, int64_t dstStride,
15541619
int64_t srcColDataOffset, int64_t srcRectWidth, int64_t numBlockRows,
15551620
int64_t numBlockCols) override {
1621+
if (recordingMode_) return;
15561622
@autoreleasepool {
15571623
if (numBlockRows <= 0 || numBlockCols <= 0) return;
15581624
assembleWasCalled_ = true;
@@ -1636,6 +1702,7 @@ virtual void assemble(float* data, int64_t rectRowBegin, int64_t dstStride,
16361702
// CPU path for dense ops, GPU path for deferred execution.
16371703
virtual int64_t perturbSmallDiagonals(int64_t n, float* data, int64_t offset, int64_t stride,
16381704
float threshold) override {
1705+
if (recordingMode_) return 0;
16391706
@autoreleasepool {
16401707
if (n <= 0) return 0;
16411708

@@ -1695,6 +1762,7 @@ virtual int64_t perturbSmallDiagonals(int64_t n, float* data, int64_t offset, in
16951762
// ============ LU factorization methods ============
16961763

16971764
virtual int getrf(int64_t m, int64_t n, float* data, int64_t offA, int64_t* pivots) override {
1765+
if (recordingMode_) { flushPendingGemms(); return 0; }
16981766
@autoreleasepool {
16991767
if (m <= 0 || n <= 0) return 0;
17001768

@@ -1807,6 +1875,7 @@ virtual int getrf(int64_t m, int64_t n, float* data, int64_t offA, int64_t* pivo
18071875

18081876
virtual void trsmLowerUnit(int64_t m, int64_t n, const float* L, int64_t offL, float* B,
18091877
int64_t offB, int64_t ldb) override {
1878+
if (recordingMode_) return;
18101879
@autoreleasepool {
18111880
if (m <= 0 || n <= 0) return;
18121881

@@ -1857,6 +1926,7 @@ virtual void trsmLowerUnit(int64_t m, int64_t n, const float* L, int64_t offL, f
18571926

18581927
virtual void trsmUpperRight(int64_t m, int64_t n, const float* U, int64_t offU, float* B,
18591928
int64_t offB, int64_t ldb) override {
1929+
if (recordingMode_) return;
18601930
@autoreleasepool {
18611931
if (m <= 0 || n <= 0) return;
18621932

@@ -1977,13 +2047,26 @@ virtual void saveGemm(int64_t m, int64_t n, int64_t k, const float* L, int64_t o
19772047
item.n = n;
19782048
item.k = k;
19792049

2050+
if (recordingMode_) {
2051+
recordedItems_.push_back(item);
2052+
recordingBatchCount_++;
2053+
sym.luGemmCalls++;
2054+
return;
2055+
}
2056+
if (usePrecomputed_) {
2057+
// Items already on device — dispatched from pre-computed buffer in flushPendingGemms
2058+
sym.luGemmCalls++;
2059+
return;
2060+
}
2061+
19802062
pendingGemms_.push_back(item);
19812063
sym.luGemmCalls++;
19822064
}
19832065
}
19842066

19852067
virtual void applyRowPerm(int64_t* pivots, int64_t n, float* data, int64_t offData, int64_t ld,
19862068
int64_t numCols) override {
2069+
if (recordingMode_) return;
19872070
@autoreleasepool {
19882071
if (n <= 0 || numCols <= 0) return;
19892072

@@ -2048,6 +2131,8 @@ virtual void applyRowPerm(int64_t* pivots, int64_t n, float* data, int64_t offDa
20482131
}
20492132

20502133
void flush() override {
2134+
flushPendingGemms();
2135+
if (recordingMode_) return; // skip pivot copies during recording
20512136
commitAndWait();
20522137

20532138
// Reset batched state for next factorization
@@ -2078,8 +2163,12 @@ void reset() override {
20782163
deferredElimPerturbBuf_ = nil;
20792164
assembleWasCalled_ = false;
20802165
potrfStatusPending_ = false;
2166+
// Reset pre-computed dispatch index (buffers persist across calls)
2167+
if (usePrecomputed_) {
2168+
precomputedFlushIdx_ = 0;
2169+
}
20812170
// Buffers (tempBuffer, devSpanToChainOffset, devPivots, devAllPivots,
2082-
// devGemmWorkBuf_, perturbCountBuf_) are NOT freed — reused across calls.
2171+
// devGemmWorkBuf_, perturbCountBuf_, devPrecomputedItems_) are NOT freed — reused across calls.
20832172
}
20842173

20852174
// Pre-allocate all Metal buffers to max needed sizes so no allocation occurs
@@ -2112,6 +2201,7 @@ void preAllocateForLU(int64_t maxDenseBlockSize, int64_t totalDensePivots) overr
21122201
// backing stores — both are CPU-accessible, so no D->H->D round-trip.
21132202
void flushDevicePivots(int64_t* devDstPivots) override {
21142203
flushPendingGemms();
2204+
if (recordingMode_) return;
21152205
if (pivotsOnGpu_ && allPivotsCount_ > 0) {
21162206
// On Metal unified memory, devAllPivots.ptr() and devDstPivots are both
21172207
// CPU-accessible pointers to Metal buffer backing stores.
@@ -2123,6 +2213,7 @@ void flushDevicePivots(int64_t* devDstPivots) override {
21232213
}
21242214

21252215
int64_t deferredPerturbCount() override {
2216+
if (recordingMode_) return 0;
21262217
// Read the accumulated GPU atomic counter (valid after flush/commitAndWait)
21272218
int64_t count = 0;
21282219
if (perturbCountPending_ && perturbCountBuf_) {
@@ -2137,6 +2228,43 @@ int64_t deferredPerturbCount() override {
21372228
return count;
21382229
}
21392230

2231+
// ============ Recording mode API ============
2232+
2233+
void beginRecording() override {
2234+
recordingMode_ = true;
2235+
recordedItems_.clear();
2236+
recordedFlushPoints_.clear();
2237+
recordingBatchCount_ = 0;
2238+
}
2239+
2240+
void endRecording() override {
2241+
// Flush any remaining batch
2242+
if (recordingBatchCount_ > 0) {
2243+
size_t startIdx = recordedItems_.size() - recordingBatchCount_;
2244+
recordedFlushPoints_.push_back({startIdx, recordingBatchCount_});
2245+
recordingBatchCount_ = 0;
2246+
}
2247+
2248+
recordingMode_ = false;
2249+
totalPrecomputedItems_ = recordedItems_.size();
2250+
2251+
if (totalPrecomputedItems_ > 0) {
2252+
// Upload all recorded items to device (single memcpy at init time).
2253+
// On Metal unified memory this is fast — just a CPU write to shared buffer.
2254+
size_t bytes = totalPrecomputedItems_ * sizeof(LUGemmWorkItem);
2255+
size_t int64sNeeded = (bytes + sizeof(int64_t) - 1) / sizeof(int64_t);
2256+
devPrecomputedItems_.resizeToAtLeast(int64sNeeded);
2257+
memcpy(devPrecomputedItems_.ptr(), recordedItems_.data(), bytes);
2258+
}
2259+
2260+
usePrecomputed_ = true;
2261+
precomputedFlushIdx_ = 0;
2262+
2263+
// Free host recording buffers (data is now on device)
2264+
recordedItems_.clear();
2265+
recordedItems_.shrink_to_fit();
2266+
}
2267+
21402268
MetalSymbolicCtx& sym;
21412269
int64_t numSpans_;
21422270
MetalMirror<float> tempBuffer;
@@ -2185,6 +2313,20 @@ int64_t deferredPerturbCount() override {
21852313
// Scratch buffer for two-phase deterministic sparse elimination
21862314
MetalMirror<float> elimScratchBuffer;
21872315

2316+
// ============ Recording mode for pre-computed GemmWorkItems ============
2317+
// During recording: capture all LUGemmWorkItems and flush boundaries.
2318+
// After endRecording(): dispatch from pre-computed device buffers.
2319+
// This eliminates per-lump CPU memcpy in flushPendingGemms.
2320+
bool recordingMode_ = false;
2321+
std::vector<LUGemmWorkItem> recordedItems_; // all items across all flushes
2322+
std::vector<std::pair<size_t, size_t>> recordedFlushPoints_; // (startIdx, count) per flush
2323+
size_t recordingBatchCount_ = 0; // items in current batch
2324+
2325+
bool usePrecomputed_ = false;
2326+
MetalMirror<int64_t> devPrecomputedItems_; // LUGemmWorkItems on device (as int64_t for MetalMirror)
2327+
size_t precomputedFlushIdx_ = 0; // current flush point index during dispatch
2328+
size_t totalPrecomputedItems_ = 0; // total items for bounds checking
2329+
21882330
};
21892331

21902332
// Solve context for float - Metal implementation

baspacho/baspacho/MetalKernels.metal

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,25 @@ kernel void lu_batchedSaveGemm_kernel_float(
11341134
}
11351135
}
11361136

1137+
// prepareAssemble: GPU kernel that replaces CPU loop + memcpy.
1138+
// Reads device-resident skeleton arrays (chainColPtr, chainRowSpan, chainData)
1139+
// and writes spanToChainOffset[chainRowSpan[i]] = chainData[i] for all chain
1140+
// entries of the target lump. One thread per chain entry.
1141+
kernel void prepareAssemble_kernel_float(
1142+
constant int64_t* chainColPtr [[buffer(0)]],
1143+
constant int64_t* chainRowSpan [[buffer(1)]],
1144+
constant int64_t* chainData [[buffer(2)]],
1145+
device int64_t* spanToChainOffset [[buffer(3)]],
1146+
constant int64_t& targetLump [[buffer(4)]],
1147+
uint tid [[thread_position_in_grid]])
1148+
{
1149+
int64_t start = chainColPtr[targetLump];
1150+
int64_t end = chainColPtr[targetLump + 1];
1151+
if (int64_t(tid) >= end - start) return;
1152+
int64_t i = start + int64_t(tid);
1153+
spanToChainOffset[chainRowSpan[i]] = chainData[i];
1154+
}
1155+
11371156
// saveGemm: C -= L * U (all row-major with strides)
11381157
kernel void lu_saveGemm_kernel_float(
11391158
constant float* L [[buffer(0)]],

0 commit comments

Comments (0)