@@ -687,6 +687,7 @@ void collectDeferredElimPerturb() {
687687
688688 // Signal start of dense LU operations — wait for deferred sparse elim GPU work.
689689 void beginDenseOps (float * data, int64_t totalDataSize) override {
690+ if (recordingMode_) return ; // no-op during recording
690691 (void )data;
691692 (void )totalDataSize;
692693 waitForGpu ();
@@ -706,6 +707,46 @@ void commitAndWait() {
706707 // devGemmWorkBuf_, avoiding the need to wait for GPU to finish reading
707708 // previous data. Buffer reset happens at flush() time between factorizations.
708709 void flushPendingGemms () {
710+ if (recordingMode_) {
711+ // Recording mode: record flush point, skip GPU work
712+ if (recordingBatchCount_ > 0 ) {
713+ size_t startIdx = recordedItems_.size () - recordingBatchCount_;
714+ recordedFlushPoints_.push_back ({startIdx, recordingBatchCount_});
715+ recordingBatchCount_ = 0 ;
716+ }
717+ pendingGemms_.clear ();
718+ return ;
719+ }
720+
721+ if (usePrecomputed_) {
722+ // Pre-computed mode: dispatch from device-resident items
723+ if (precomputedFlushIdx_ >= recordedFlushPoints_.size ()) return ;
724+ auto [startIdx, count] = recordedFlushPoints_[precomputedFlushIdx_];
725+ precomputedFlushIdx_++;
726+ if (count == 0 ) return ;
727+
728+ // Byte offset into pre-computed buffer (MetalMirror backing is already aligned)
729+ size_t byteOffset = startIdx * sizeof (LUGemmWorkItem);
730+
731+ int64_t countI64 = (int64_t )count;
732+
733+ id <MTLComputePipelineState > pipeline = getProfiledPipeline (
734+ " lu_batchedSaveGemm_kernel_float" );
735+
736+ encodeKernel (
737+ pipeline,
738+ ^(id <MTLComputeCommandEncoder > encoder) {
739+ [encoder setBuffer: cachedDataBuffer_ offset: 0 atIndex: 0 ];
740+ [encoder setBuffer: (__bridge id <MTLBuffer >)devPrecomputedItems_.buffer ()
741+ offset: byteOffset
742+ atIndex: 1 ];
743+ [encoder setBytes: &countI64 length: sizeof (int64_t ) atIndex: 2 ];
744+ },
745+ (NSUInteger )count);
746+ pendingGemms_.clear ();
747+ return ;
748+ }
749+
709750 if (pendingGemms_.empty ()) return ;
710751
711752 int64_t count = (int64_t )pendingGemms_.size ();
@@ -756,6 +797,7 @@ void flushPendingGemms() {
756797 }
757798
758799 virtual void pseudoFactorSpans (float * data, int64_t spanBegin, int64_t spanEnd) override {
800+ if (recordingMode_) return ;
759801 @autoreleasepool {
760802 // Find the MTLBuffer for data
761803 auto bufferInfo = MetalBufferRegistry::instance ().findBuffer (data);
@@ -805,6 +847,7 @@ virtual void pseudoFactorSpans(float* data, int64_t spanBegin, int64_t spanEnd)
805847
806848 virtual void doElimination (const SymElimCtx& elimData, float * data, int64_t lumpsBegin,
807849 int64_t lumpsEnd) override {
850+ if (recordingMode_) return ;
808851 @autoreleasepool {
809852 const MetalSymElimCtx* pElim = dynamic_cast <const MetalSymElimCtx*>(&elimData);
810853 BASPACHO_CHECK_NOTNULL (pElim);
@@ -903,6 +946,7 @@ virtual void doElimination(const SymElimCtx& elimData, float* data, int64_t lump
903946 virtual void doEliminationLU (const SymElimCtx& elimData, float * data, int64_t lumpsBegin,
904947 int64_t lumpsEnd, float staticPivotThreshold,
905948 int64_t & perturbCount) override {
949+ if (recordingMode_) return ;
906950 @autoreleasepool {
907951 const MetalSymElimCtx* pElim = dynamic_cast <const MetalSymElimCtx*>(&elimData);
908952 BASPACHO_CHECK_NOTNULL (pElim);
@@ -1018,6 +1062,7 @@ void doAllEliminationsLU(const std::vector<SymElimCtxPtr>& elimCtxs,
10181062 const std::vector<int64_t >& ranges, float * data,
10191063 float staticPivotThreshold,
10201064 int64_t & totalPerturbCount) override {
1065+ if (recordingMode_) return ;
10211066 @autoreleasepool {
10221067 auto bufferInfo = MetalBufferRegistry::instance ().findBuffer (data);
10231068 if (!bufferInfo.first ) {
@@ -1160,6 +1205,7 @@ void doAllEliminationsLU(const std::vector<SymElimCtxPtr>& elimCtxs,
11601205 // Phase 2: segmented sum per target in fixed order (deterministic)
11611206 void doAllEliminations (const std::vector<SymElimCtxPtr>& elimCtxs,
11621207 const std::vector<int64_t >& ranges, float * data) override {
1208+ if (recordingMode_) return ;
11631209 @autoreleasepool {
11641210 auto bufferInfo = MetalBufferRegistry::instance ().findBuffer (data);
11651211 if (!bufferInfo.first ) {
@@ -1286,6 +1332,7 @@ void doAllEliminations(const std::vector<SymElimCtxPtr>& elimCtxs,
12861332
12871333 virtual double maxAbsDiag (const float * data, const int64_t * lumpStart, const int64_t * chainColPtr,
12881334 const int64_t * chainData, int64_t startLump, int64_t upToLump) override {
1335+ if (recordingMode_) return 0.0 ;
12891336 @autoreleasepool {
12901337 int64_t numLumps = upToLump - startLump;
12911338 if (numLumps <= 0 ) return 0.0 ;
@@ -1332,6 +1379,7 @@ virtual double maxAbsDiag(const float* data, const int64_t* lumpStart, const int
13321379 }
13331380
13341381 virtual void potrf (int64_t n, float * data, int64_t offA) override {
1382+ if (recordingMode_) return ;
13351383 @autoreleasepool {
13361384 if (n <= 0 ) return ;
13371385
@@ -1387,6 +1435,7 @@ virtual void potrf(int64_t n, float* data, int64_t offA) override {
13871435 }
13881436
13891437 virtual void trsm (int64_t n, int64_t k, float * data, int64_t offA, int64_t offB) override {
1438+ if (recordingMode_) return ;
13901439 @autoreleasepool {
13911440 if (n <= 0 || k <= 0 ) return ;
13921441
@@ -1449,6 +1498,7 @@ virtual void trsm(int64_t n, int64_t k, float* data, int64_t offA, int64_t offB)
14491498
14501499 virtual void saveSyrkGemm (int64_t m, int64_t n, int64_t k, const float * data,
14511500 int64_t offset) override {
1501+ if (recordingMode_) return ;
14521502 @autoreleasepool {
14531503 if (m <= 0 || n <= 0 || k <= 0 ) return ;
14541504
@@ -1528,6 +1578,8 @@ virtual void saveSyrkGemm(int64_t m, int64_t n, int64_t k, const float* data,
15281578 }
15291579
15301580 virtual void prepareAssemble (int64_t targetLump) override {
1581+ if (recordingMode_) return ; // no-op during recording
1582+
15311583 // Only flush if assemble() was actually called since last prepareAssemble.
15321584 // For LU factorization (isGeneral()==true), eliminateBoardLU only calls
15331585 // saveGemm — never assemble — so flushing is unnecessary and avoiding it
@@ -1537,22 +1589,36 @@ virtual void prepareAssemble(int64_t targetLump) override {
15371589 assembleWasCalled_ = false ;
15381590 }
15391591
1540- // Prepare chain offset mapping for assembly (same as CUDA version)
1592+ // GPU kernel: reads device-resident skeleton arrays, writes devSpanToChainOffset.
1593+ // All inputs (devChainColPtr, devChainRowSpan, devChainData) are already on device
1594+ // in MetalSymbolicCtx. No CPU loop, no memcpy — eliminates CPU→GPU sync point.
15411595 const CoalescedBlockMatrixSkel& skel = sym.skel ;
1596+ int64_t numEntries = skel.chainColPtr [targetLump + 1 ] - skel.chainColPtr [targetLump];
1597+ if (numEntries <= 0 ) return ;
15421598
1543- for (int64_t i = skel.chainColPtr [targetLump], iEnd = skel.chainColPtr [targetLump + 1 ]; i < iEnd;
1544- i++) {
1545- spanToChainOffset[skel.chainRowSpan[i]] = skel.chainData [i];
1546- }
1599+ id <MTLComputePipelineState > pipeline = getProfiledPipeline (
1600+ " prepareAssemble_kernel_float" );
15471601
1548- // Copy to device
1549- memcpy (devSpanToChainOffset.ptr (), spanToChainOffset.data (),
1550- spanToChainOffset.size () * sizeof (int64_t ));
1602+ encodeKernel (
1603+ pipeline,
1604+ ^(id <MTLComputeCommandEncoder > encoder) {
1605+ [encoder setBuffer: (__bridge id <MTLBuffer >)sym.devChainColPtr.buffer ()
1606+ offset: 0 atIndex: 0 ];
1607+ [encoder setBuffer: (__bridge id <MTLBuffer >)sym.devChainRowSpan.buffer ()
1608+ offset: 0 atIndex: 1 ];
1609+ [encoder setBuffer: (__bridge id <MTLBuffer >)sym.devChainData.buffer ()
1610+ offset: 0 atIndex: 2 ];
1611+ [encoder setBuffer: (__bridge id <MTLBuffer >)devSpanToChainOffset.buffer ()
1612+ offset: 0 atIndex: 3 ];
1613+ [encoder setBytes: &targetLump length: sizeof (int64_t ) atIndex: 4 ];
1614+ },
1615+ (NSUInteger )numEntries);
15511616 }
15521617
15531618 virtual void assemble (float * data, int64_t rectRowBegin, int64_t dstStride,
15541619 int64_t srcColDataOffset, int64_t srcRectWidth, int64_t numBlockRows,
15551620 int64_t numBlockCols) override {
1621+ if (recordingMode_) return ;
15561622 @autoreleasepool {
15571623 if (numBlockRows <= 0 || numBlockCols <= 0 ) return ;
15581624 assembleWasCalled_ = true ;
@@ -1636,6 +1702,7 @@ virtual void assemble(float* data, int64_t rectRowBegin, int64_t dstStride,
16361702 // CPU path for dense ops, GPU path for deferred execution.
16371703 virtual int64_t perturbSmallDiagonals (int64_t n, float * data, int64_t offset, int64_t stride,
16381704 float threshold) override {
1705+ if (recordingMode_) return 0 ;
16391706 @autoreleasepool {
16401707 if (n <= 0 ) return 0 ;
16411708
@@ -1695,6 +1762,7 @@ virtual int64_t perturbSmallDiagonals(int64_t n, float* data, int64_t offset, in
16951762 // ============ LU factorization methods ============
16961763
16971764 virtual int getrf (int64_t m, int64_t n, float * data, int64_t offA, int64_t * pivots) override {
1765+ if (recordingMode_) { flushPendingGemms (); return 0 ; }
16981766 @autoreleasepool {
16991767 if (m <= 0 || n <= 0 ) return 0 ;
17001768
@@ -1807,6 +1875,7 @@ virtual int getrf(int64_t m, int64_t n, float* data, int64_t offA, int64_t* pivo
18071875
18081876 virtual void trsmLowerUnit (int64_t m, int64_t n, const float * L, int64_t offL, float * B,
18091877 int64_t offB, int64_t ldb) override {
1878+ if (recordingMode_) return ;
18101879 @autoreleasepool {
18111880 if (m <= 0 || n <= 0 ) return ;
18121881
@@ -1857,6 +1926,7 @@ virtual void trsmLowerUnit(int64_t m, int64_t n, const float* L, int64_t offL, f
18571926
18581927 virtual void trsmUpperRight (int64_t m, int64_t n, const float * U, int64_t offU, float * B,
18591928 int64_t offB, int64_t ldb) override {
1929+ if (recordingMode_) return ;
18601930 @autoreleasepool {
18611931 if (m <= 0 || n <= 0 ) return ;
18621932
@@ -1977,13 +2047,26 @@ virtual void saveGemm(int64_t m, int64_t n, int64_t k, const float* L, int64_t o
19772047 item.n = n;
19782048 item.k = k;
19792049
2050+ if (recordingMode_) {
2051+ recordedItems_.push_back (item);
2052+ recordingBatchCount_++;
2053+ sym.luGemmCalls ++;
2054+ return ;
2055+ }
2056+ if (usePrecomputed_) {
2057+ // Items already on device — dispatched from pre-computed buffer in flushPendingGemms
2058+ sym.luGemmCalls ++;
2059+ return ;
2060+ }
2061+
19802062 pendingGemms_.push_back (item);
19812063 sym.luGemmCalls ++;
19822064 }
19832065 }
19842066
19852067 virtual void applyRowPerm (int64_t * pivots, int64_t n, float * data, int64_t offData, int64_t ld,
19862068 int64_t numCols) override {
2069+ if (recordingMode_) return ;
19872070 @autoreleasepool {
19882071 if (n <= 0 || numCols <= 0 ) return ;
19892072
@@ -2048,6 +2131,8 @@ virtual void applyRowPerm(int64_t* pivots, int64_t n, float* data, int64_t offDa
20482131 }
20492132
20502133 void flush () override {
2134+ flushPendingGemms ();
2135+ if (recordingMode_) return ; // skip pivot copies during recording
20512136 commitAndWait ();
20522137
20532138 // Reset batched state for next factorization
@@ -2078,8 +2163,12 @@ void reset() override {
20782163 deferredElimPerturbBuf_ = nil ;
20792164 assembleWasCalled_ = false ;
20802165 potrfStatusPending_ = false ;
2166+ // Reset pre-computed dispatch index (buffers persist across calls)
2167+ if (usePrecomputed_) {
2168+ precomputedFlushIdx_ = 0 ;
2169+ }
20812170 // Buffers (tempBuffer, devSpanToChainOffset, devPivots, devAllPivots,
2082- // devGemmWorkBuf_, perturbCountBuf_) are NOT freed — reused across calls.
2171+ // devGemmWorkBuf_, perturbCountBuf_, devPrecomputedItems_ ) are NOT freed — reused across calls.
20832172 }
20842173
20852174 // Pre-allocate all Metal buffers to max needed sizes so no allocation occurs
@@ -2112,6 +2201,7 @@ void preAllocateForLU(int64_t maxDenseBlockSize, int64_t totalDensePivots) overr
21122201 // backing stores — both are CPU-accessible, so no D->H->D round-trip.
21132202 void flushDevicePivots (int64_t * devDstPivots) override {
21142203 flushPendingGemms ();
2204+ if (recordingMode_) return ;
21152205 if (pivotsOnGpu_ && allPivotsCount_ > 0 ) {
21162206 // On Metal unified memory, devAllPivots.ptr() and devDstPivots are both
21172207 // CPU-accessible pointers to Metal buffer backing stores.
@@ -2123,6 +2213,7 @@ void flushDevicePivots(int64_t* devDstPivots) override {
21232213 }
21242214
21252215 int64_t deferredPerturbCount () override {
2216+ if (recordingMode_) return 0 ;
21262217 // Read the accumulated GPU atomic counter (valid after flush/commitAndWait)
21272218 int64_t count = 0 ;
21282219 if (perturbCountPending_ && perturbCountBuf_) {
@@ -2137,6 +2228,43 @@ int64_t deferredPerturbCount() override {
21372228 return count;
21382229 }
21392230
// ============ Recording mode API ============
void beginRecording() override {
  // Discard any state left over from a previous recording session,
  // then arm recording so subsequent saveGemm() calls are captured
  // instead of being enqueued for GPU dispatch.
  recordedItems_.clear();
  recordedFlushPoints_.clear();
  recordingBatchCount_ = 0;
  recordingMode_ = true;
}
2239+
void endRecording() override {
  // Close out the batch still open at the end of recording (saveGemm
  // calls made since the last flushPendingGemms()), mirroring the
  // flush-point bookkeeping done in flushPendingGemms().
  if (recordingBatchCount_ > 0) {
    const size_t firstIdx = recordedItems_.size() - recordingBatchCount_;
    recordedFlushPoints_.emplace_back(firstIdx, recordingBatchCount_);
    recordingBatchCount_ = 0;
  }

  recordingMode_ = false;
  totalPrecomputedItems_ = recordedItems_.size();

  if (totalPrecomputedItems_ > 0) {
    // One-time upload of every recorded item into the device-resident
    // buffer. On Metal unified memory this is just a CPU write into the
    // shared backing store — no explicit transfer needed.
    // devPrecomputedItems_ is a MetalMirror<int64_t>, so round the byte
    // size up to whole int64_t words before resizing.
    const size_t byteCount = totalPrecomputedItems_ * sizeof(LUGemmWorkItem);
    const size_t wordCount = (byteCount + sizeof(int64_t) - 1) / sizeof(int64_t);
    devPrecomputedItems_.resizeToAtLeast(wordCount);
    memcpy(devPrecomputedItems_.ptr(), recordedItems_.data(), byteCount);
  }

  precomputedFlushIdx_ = 0;
  usePrecomputed_ = true;

  // The host-side copy is no longer needed (flush points stay on the
  // host, items live on the device) — release its storage.
  recordedItems_.clear();
  recordedItems_.shrink_to_fit();
}
2267+
21402268 MetalSymbolicCtx& sym;
21412269 int64_t numSpans_;
21422270 MetalMirror<float > tempBuffer;
@@ -2185,6 +2313,20 @@ int64_t deferredPerturbCount() override {
21852313 // Scratch buffer for two-phase deterministic sparse elimination
21862314 MetalMirror<float > elimScratchBuffer;
21872315
2316+ // ============ Recording mode for pre-computed GemmWorkItems ============
2317+ // During recording: capture all LUGemmWorkItems and flush boundaries.
2318+ // After endRecording(): dispatch from pre-computed device buffers.
2319+ // This eliminates per-lump CPU memcpy in flushPendingGemms.
2320+ bool recordingMode_ = false ;
2321+ std::vector<LUGemmWorkItem> recordedItems_; // all items across all flushes
2322+ std::vector<std::pair<size_t , size_t >> recordedFlushPoints_; // (startIdx, count) per flush
2323+ size_t recordingBatchCount_ = 0 ; // items in current batch
2324+
2325+ bool usePrecomputed_ = false ;
2326+ MetalMirror<int64_t > devPrecomputedItems_; // LUGemmWorkItems on device (as int64_t for MetalMirror)
2327+ size_t precomputedFlushIdx_ = 0 ; // current flush point index during dispatch
2328+ size_t totalPrecomputedItems_ = 0 ; // total items for bounds checking
2329+
21882330};
21892331
21902332// Solve context for float - Metal implementation
0 commit comments