Skip to content

Commit 19644d3

Browse files
committed
feat: Dual-path getrf — custom kernel for external encoder, MPS for normal mode
Replace the MPS-only getrf with a dual-path implementation:

- External encoder mode: uses lu_getrf_kernel_float compute shader via encodeKernel(), compatible with external command buffer recording. Outputs int64_t pivots directly (no uint32→int64 conversion needed).
- Normal mode: uses MPS MPSMatrixDecompositionLU for optimal performance on larger blocks (encodeToCommandBuffer is incompatible with external encoder but faster due to hardware-optimized implementation).

This was the only remaining blocker for external encoder compatibility in the LU factorization hot path. All other operations (trsm, assemble, applyRowPerm, prepareAssemble, flushPendingGemms) already use encodeKernel(), which routes correctly through the external encoder.

Performance (ring 47x47, 500 reps, M4 Pro):
- MPS baseline: factor 0.74ms, solve 0.58ms, total 1.33ms
- Dual-path: factor 0.74ms, solve 0.58ms, total 1.33ms
- Custom-only: factor 2.51ms (3.4x slower — expected for a single-threaded kernel)

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
1 parent c9d03dc commit 19644d3

File tree

1 file changed

+141
-80
lines changed

1 file changed

+141
-80
lines changed

baspacho/baspacho/MatOpsMetal.mm

Lines changed: 141 additions & 80 deletions
Original file line number · Diff line number · Diff line change
@@ -1786,101 +1786,162 @@ virtual int getrf(int64_t m, int64_t n, float* data, int64_t offA, int64_t* pivo
17861786
id<MTLBuffer> dataBuffer = (__bridge id<MTLBuffer>)bufferInfo.first;
17871787
size_t dataBaseOffset = bufferInfo.second;
17881788

1789-
// Ensure devPivots is large enough for this lump's pivots
1790-
devPivots.resizeToAtLeast(minMN);
1791-
1792-
// MPS LU factorization on GPU for all sizes.
17931789
// Flush pending saveGemm work items first — ensures all Schur
17941790
// complement updates are dispatched before factorization of this lump.
17951791
flushPendingGemms();
17961792

1797-
// End pending compute encoder (MPS needs its own encoding)
1798-
if (pendingEncoder_) {
1799-
[pendingEncoder_ endEncoding];
1800-
pendingEncoder_ = nil;
1793+
// External encoder mode: use custom compute kernel (encodeKernel-compatible).
1794+
// MPS encodeToCommandBuffer is incompatible with external encoder.
1795+
// Normal mode: use MPS for better performance on larger blocks.
1796+
if (sym.usingExternalEncoder) {
1797+
return getrfCustomKernel(m, n, minMN, dataBuffer, dataBaseOffset, offA, pivots);
1798+
} else {
1799+
return getrfMPS(m, n, minMN, dataBuffer, dataBaseOffset, offA, pivots);
18011800
}
1801+
}
1802+
}
18021803

1803-
// Ensure we have a command buffer
1804-
if (!pendingCmdBuf_) {
1805-
pendingCmdBuf_ = [sym.commandQueue commandBuffer];
1804+
// Custom compute kernel for LU factorization — external encoder compatible.
1805+
// Single-threaded sequential LU with partial pivoting on GPU.
1806+
// Outputs int64_t pivots directly — no uint32→int64 conversion needed.
1807+
int getrfCustomKernel(int64_t m, int64_t n, int64_t minMN,
1808+
id<MTLBuffer> dataBuffer, size_t dataBaseOffset,
1809+
int64_t offA, int64_t* pivots) {
1810+
id<MTLComputePipelineState> pipeline = getProfiledPipeline(
1811+
"lu_getrf_kernel_float");
1812+
1813+
// Compute absolute offset from buffer start (element offset)
1814+
int64_t absOffA = (int64_t)(dataBaseOffset / sizeof(float)) + offA;
1815+
1816+
// Determine pivot output buffer and offset
1817+
id<MTLBuffer> pivotBuffer;
1818+
size_t pivotByteOffset = 0;
1819+
if (devAllPivots.buffer()) {
1820+
// GPU-resident pivot path: write directly into persistent devAllPivots
1821+
if (!allPivotsCpuBase_) {
1822+
allPivotsCpuBase_ = pivots;
18061823
}
1824+
int64_t pivotElemOffset = pivots - allPivotsCpuBase_;
1825+
allPivotsCount_ = std::max(allPivotsCount_, pivotElemOffset + minMN);
1826+
pivotBuffer = (__bridge id<MTLBuffer>)devAllPivots.buffer();
1827+
pivotByteOffset = pivotElemOffset * sizeof(int64_t);
1828+
pivotsOnGpu_ = true;
1829+
} else {
1830+
devPivots.resizeToAtLeast(minMN);
1831+
pivotBuffer = (__bridge id<MTLBuffer>)devPivots.buffer();
1832+
pivotByteOffset = 0;
1833+
}
18071834

1808-
// Create MPSMatrix view for the block at data+offA (row-major, m×n, stride=n)
1809-
MPSMatrixDescriptor* descA = [MPSMatrixDescriptor
1810-
matrixDescriptorWithRows:m columns:n
1811-
rowBytes:n * sizeof(float) dataType:MPSDataTypeFloat32];
1812-
MPSMatrix* mpsA = [[MPSMatrix alloc]
1813-
initWithBuffer:dataBuffer
1814-
offset:dataBaseOffset + offA * sizeof(float)
1815-
descriptor:descA];
1835+
encodeKernel(
1836+
pipeline,
1837+
^(id<MTLComputeCommandEncoder> encoder) {
1838+
[encoder setBuffer:dataBuffer offset:0 atIndex:0];
1839+
[encoder setBytes:&absOffA length:sizeof(int64_t) atIndex:1];
1840+
[encoder setBytes:&m length:sizeof(int64_t) atIndex:2];
1841+
[encoder setBytes:&n length:sizeof(int64_t) atIndex:3];
1842+
[encoder setBuffer:pivotBuffer offset:pivotByteOffset atIndex:4];
1843+
},
1844+
1);
18161845

1817-
// Pivot buffer (UInt32 format required by MPS)
1818-
devPivotBuf32.resizeToAtLeast(minMN);
1819-
MPSMatrixDescriptor* descPiv = [MPSMatrixDescriptor
1820-
matrixDescriptorWithRows:1 columns:minMN
1821-
rowBytes:minMN * sizeof(uint32_t) dataType:MPSDataTypeUInt32];
1822-
MPSMatrix* mpsPiv = [[MPSMatrix alloc]
1823-
initWithBuffer:(__bridge id<MTLBuffer>)devPivotBuf32.buffer()
1824-
offset:0 descriptor:descPiv];
1825-
1826-
// Encode MPS LU factorization (in-place: resultMatrix = sourceMatrix)
1827-
MPSMatrixDecompositionLU* mpsLU = [[MPSMatrixDecompositionLU alloc]
1828-
initWithDevice:sym.device rows:m columns:n];
1829-
[mpsLU encodeToCommandBuffer:pendingCmdBuf_
1830-
sourceMatrix:mpsA resultMatrix:mpsA
1831-
pivotIndices:mpsPiv status:nil];
1832-
1833-
// When profiling, commit and wait to get per-getrf GPU timestamps
1834-
if (metalProfilingEnabled()) {
1835-
[pendingCmdBuf_ commit];
1836-
[pendingCmdBuf_ waitUntilCompleted];
1837-
double gpuTimeMs = ([pendingCmdBuf_ GPUEndTime] - [pendingCmdBuf_ GPUStartTime]) * 1000.0;
1838-
NSLog(@"[GPU] %-45s size=%lldx%lld gpu=%.3fms",
1839-
"MPS_LU_getrf", m, n, gpuTimeMs);
1840-
pendingCmdBuf_ = nil;
1846+
if (!devAllPivots.buffer()) {
1847+
// Non-general fallback: commit and read pivots back to CPU
1848+
commitPending();
1849+
waitForGpu();
1850+
int64_t* gpuPivots = devPivots.ptr();
1851+
for (int64_t i = 0; i < minMN; i++) {
1852+
pivots[i] = gpuPivots[i];
18411853
}
1854+
}
18421855

1843-
// GPU-resident pivot path: for general (LU) matrices with pre-allocated
1844-
// devAllPivots, encode a GPU-side uint32→int64 conversion kernel and
1845-
// keep pivots on GPU. For non-general matrices (e.g. simple LU tests),
1846-
// fall back to CPU conversion with commitAndWait.
1847-
if (devAllPivots.buffer()) {
1848-
// Compute offset into the persistent all-pivots buffer.
1849-
if (!allPivotsCpuBase_) {
1850-
allPivotsCpuBase_ = pivots; // First getrf call — record base
1851-
}
1852-
int64_t pivotOffset = pivots - allPivotsCpuBase_;
1853-
allPivotsCount_ = std::max(allPivotsCount_, pivotOffset + minMN);
1854-
1855-
// Encode GPU-side pivot conversion (uint32→int64) into the same cmd buffer.
1856-
int64_t pivotByteOffset = pivotOffset * sizeof(int64_t);
1857-
id<MTLComputePipelineState> convertPipeline = getProfiledPipeline(
1858-
"lu_convertPivots_kernel_float");
1859-
encodeKernel(
1860-
convertPipeline,
1861-
^(id<MTLComputeCommandEncoder> encoder) {
1862-
[encoder setBuffer:(__bridge id<MTLBuffer>)devPivotBuf32.buffer()
1863-
offset:0 atIndex:0];
1864-
[encoder setBuffer:(__bridge id<MTLBuffer>)devAllPivots.buffer()
1865-
offset:pivotByteOffset atIndex:1];
1866-
[encoder setBytes:&minMN length:sizeof(int64_t) atIndex:2];
1867-
},
1868-
(NSUInteger)minMN);
1856+
return 0;
1857+
}
18691858

1870-
// Mark pivots as GPU-resident — applyRowPerm will skip memcpy.
1871-
pivotsOnGpu_ = true;
1872-
} else {
1873-
// Fallback: non-general matrix, commit and read pivots on CPU
1874-
commitPending();
1875-
waitForGpu();
1876-
uint32_t* mpsPivots = devPivotBuf32.ptr();
1877-
for (int64_t i = 0; i < minMN; i++) {
1878-
pivots[i] = static_cast<int64_t>(mpsPivots[i]);
1879-
}
1859+
// MPS-based LU factorization — faster for normal mode but incompatible
1860+
// with external encoder (MPS requires encodeToCommandBuffer, not compute encoder).
1861+
int getrfMPS(int64_t m, int64_t n, int64_t minMN,
1862+
id<MTLBuffer> dataBuffer, size_t dataBaseOffset,
1863+
int64_t offA, int64_t* pivots) {
1864+
// End pending compute encoder (MPS needs its own encoding)
1865+
if (pendingEncoder_) {
1866+
[pendingEncoder_ endEncoding];
1867+
pendingEncoder_ = nil;
1868+
}
1869+
1870+
// Ensure we have a command buffer
1871+
if (!pendingCmdBuf_) {
1872+
pendingCmdBuf_ = [sym.commandQueue commandBuffer];
1873+
}
1874+
1875+
// Create MPSMatrix view for the block at data+offA (row-major, m×n, stride=n)
1876+
MPSMatrixDescriptor* descA = [MPSMatrixDescriptor
1877+
matrixDescriptorWithRows:m columns:n
1878+
rowBytes:n * sizeof(float) dataType:MPSDataTypeFloat32];
1879+
MPSMatrix* mpsA = [[MPSMatrix alloc]
1880+
initWithBuffer:dataBuffer
1881+
offset:dataBaseOffset + offA * sizeof(float)
1882+
descriptor:descA];
1883+
1884+
// Pivot buffer (UInt32 format required by MPS)
1885+
devPivotBuf32.resizeToAtLeast(minMN);
1886+
MPSMatrixDescriptor* descPiv = [MPSMatrixDescriptor
1887+
matrixDescriptorWithRows:1 columns:minMN
1888+
rowBytes:minMN * sizeof(uint32_t) dataType:MPSDataTypeUInt32];
1889+
MPSMatrix* mpsPiv = [[MPSMatrix alloc]
1890+
initWithBuffer:(__bridge id<MTLBuffer>)devPivotBuf32.buffer()
1891+
offset:0 descriptor:descPiv];
1892+
1893+
// Encode MPS LU factorization (in-place: resultMatrix = sourceMatrix)
1894+
MPSMatrixDecompositionLU* mpsLU = [[MPSMatrixDecompositionLU alloc]
1895+
initWithDevice:sym.device rows:m columns:n];
1896+
[mpsLU encodeToCommandBuffer:pendingCmdBuf_
1897+
sourceMatrix:mpsA resultMatrix:mpsA
1898+
pivotIndices:mpsPiv status:nil];
1899+
1900+
// When profiling, commit and wait to get per-getrf GPU timestamps
1901+
if (metalProfilingEnabled()) {
1902+
[pendingCmdBuf_ commit];
1903+
[pendingCmdBuf_ waitUntilCompleted];
1904+
double gpuTimeMs = ([pendingCmdBuf_ GPUEndTime] - [pendingCmdBuf_ GPUStartTime]) * 1000.0;
1905+
NSLog(@"[GPU] %-45s size=%lldx%lld gpu=%.3fms",
1906+
"MPS_LU_getrf", m, n, gpuTimeMs);
1907+
pendingCmdBuf_ = nil;
1908+
}
1909+
1910+
// GPU-resident pivot path: encode GPU-side uint32→int64 conversion
1911+
// and keep pivots on device. Non-general matrices fall back to CPU.
1912+
if (devAllPivots.buffer()) {
1913+
if (!allPivotsCpuBase_) {
1914+
allPivotsCpuBase_ = pivots;
18801915
}
1916+
int64_t pivotOffset = pivots - allPivotsCpuBase_;
1917+
allPivotsCount_ = std::max(allPivotsCount_, pivotOffset + minMN);
1918+
1919+
int64_t pivotByteOffset = pivotOffset * sizeof(int64_t);
1920+
id<MTLComputePipelineState> convertPipeline = getProfiledPipeline(
1921+
"lu_convertPivots_kernel_float");
1922+
encodeKernel(
1923+
convertPipeline,
1924+
^(id<MTLComputeCommandEncoder> encoder) {
1925+
[encoder setBuffer:(__bridge id<MTLBuffer>)devPivotBuf32.buffer()
1926+
offset:0 atIndex:0];
1927+
[encoder setBuffer:(__bridge id<MTLBuffer>)devAllPivots.buffer()
1928+
offset:pivotByteOffset atIndex:1];
1929+
[encoder setBytes:&minMN length:sizeof(int64_t) atIndex:2];
1930+
},
1931+
(NSUInteger)minMN);
18811932

1882-
return 0;
1933+
pivotsOnGpu_ = true;
1934+
} else {
1935+
// Fallback: non-general matrix, commit and read pivots on CPU
1936+
commitPending();
1937+
waitForGpu();
1938+
uint32_t* mpsPivots = devPivotBuf32.ptr();
1939+
for (int64_t i = 0; i < minMN; i++) {
1940+
pivots[i] = static_cast<int64_t>(mpsPivots[i]);
1941+
}
18831942
}
1943+
1944+
return 0;
18841945
}
18851946

18861947
virtual void trsmLowerUnit(int64_t m, int64_t n, const float* L, int64_t offL, float* B,

0 commit comments

Comments (0)