@@ -1786,101 +1786,162 @@ virtual int getrf(int64_t m, int64_t n, float* data, int64_t offA, int64_t* pivo
17861786 id <MTLBuffer > dataBuffer = (__bridge id <MTLBuffer >)bufferInfo.first ;
17871787 size_t dataBaseOffset = bufferInfo.second ;
17881788
1789- // Ensure devPivots is large enough for this lump's pivots
1790- devPivots.resizeToAtLeast (minMN);
1791-
1792- // MPS LU factorization on GPU for all sizes.
17931789 // Flush pending saveGemm work items first — ensures all Schur
17941790 // complement updates are dispatched before factorization of this lump.
17951791 flushPendingGemms ();
17961792
1797- // End pending compute encoder (MPS needs its own encoding)
1798- if (pendingEncoder_) {
1799- [pendingEncoder_ endEncoding ];
1800- pendingEncoder_ = nil ;
1793+ // External encoder mode: use custom compute kernel (encodeKernel-compatible).
1794+ // MPS encodeToCommandBuffer is incompatible with external encoder.
1795+ // Normal mode: use MPS for better performance on larger blocks.
1796+ if (sym.usingExternalEncoder ) {
1797+ return getrfCustomKernel (m, n, minMN, dataBuffer, dataBaseOffset, offA, pivots);
1798+ } else {
1799+ return getrfMPS (m, n, minMN, dataBuffer, dataBaseOffset, offA, pivots);
18011800 }
1801+ }
1802+ }
18021803
1803- // Ensure we have a command buffer
1804- if (!pendingCmdBuf_) {
1805- pendingCmdBuf_ = [sym.commandQueue commandBuffer ];
1804+ // Custom compute kernel for LU factorization — external encoder compatible.
1805+ // Single-threaded sequential LU with partial pivoting on GPU.
1806+ // Outputs int64_t pivots directly — no uint32→int64 conversion needed.
1807+ int getrfCustomKernel (int64_t m, int64_t n, int64_t minMN,
1808+ id <MTLBuffer > dataBuffer, size_t dataBaseOffset,
1809+ int64_t offA, int64_t * pivots) {
1810+ id <MTLComputePipelineState > pipeline = getProfiledPipeline (
1811+ " lu_getrf_kernel_float" );
1812+
1813+ // Compute absolute offset from buffer start (element offset)
1814+ int64_t absOffA = (int64_t )(dataBaseOffset / sizeof (float )) + offA;
1815+
1816+ // Determine pivot output buffer and offset
1817+ id <MTLBuffer > pivotBuffer;
1818+ size_t pivotByteOffset = 0 ;
1819+ if (devAllPivots.buffer ()) {
1820+ // GPU-resident pivot path: write directly into persistent devAllPivots
1821+ if (!allPivotsCpuBase_) {
1822+ allPivotsCpuBase_ = pivots;
18061823 }
1824+ int64_t pivotElemOffset = pivots - allPivotsCpuBase_;
1825+ allPivotsCount_ = std::max (allPivotsCount_, pivotElemOffset + minMN);
1826+ pivotBuffer = (__bridge id <MTLBuffer >)devAllPivots.buffer ();
1827+ pivotByteOffset = pivotElemOffset * sizeof (int64_t );
1828+ pivotsOnGpu_ = true ;
1829+ } else {
1830+ devPivots.resizeToAtLeast (minMN);
1831+ pivotBuffer = (__bridge id <MTLBuffer >)devPivots.buffer ();
1832+ pivotByteOffset = 0 ;
1833+ }
18071834
1808- // Create MPSMatrix view for the block at data+offA (row-major, m×n, stride=n)
1809- MPSMatrixDescriptor* descA = [MPSMatrixDescriptor
1810- matrixDescriptorWithRows: m columns: n
1811- rowBytes: n * sizeof (float ) dataType: MPSDataTypeFloat32];
1812- MPSMatrix* mpsA = [[MPSMatrix alloc ]
1813- initWithBuffer: dataBuffer
1814- offset: dataBaseOffset + offA * sizeof (float )
1815- descriptor: descA];
1835+ encodeKernel (
1836+ pipeline,
1837+ ^(id <MTLComputeCommandEncoder > encoder) {
1838+ [encoder setBuffer: dataBuffer offset: 0 atIndex: 0 ];
1839+ [encoder setBytes: &absOffA length: sizeof (int64_t ) atIndex: 1 ];
1840+ [encoder setBytes: &m length: sizeof (int64_t ) atIndex: 2 ];
1841+ [encoder setBytes: &n length: sizeof (int64_t ) atIndex: 3 ];
1842+ [encoder setBuffer: pivotBuffer offset: pivotByteOffset atIndex: 4 ];
1843+ },
1844+ 1 );
18161845
1817- // Pivot buffer (UInt32 format required by MPS)
1818- devPivotBuf32.resizeToAtLeast (minMN);
1819- MPSMatrixDescriptor* descPiv = [MPSMatrixDescriptor
1820- matrixDescriptorWithRows: 1 columns: minMN
1821- rowBytes: minMN * sizeof (uint32_t ) dataType: MPSDataTypeUInt32];
1822- MPSMatrix* mpsPiv = [[MPSMatrix alloc ]
1823- initWithBuffer: (__bridge id <MTLBuffer >)devPivotBuf32.buffer ()
1824- offset: 0 descriptor: descPiv];
1825-
1826- // Encode MPS LU factorization (in-place: resultMatrix = sourceMatrix)
1827- MPSMatrixDecompositionLU* mpsLU = [[MPSMatrixDecompositionLU alloc ]
1828- initWithDevice: sym.device rows: m columns: n];
1829- [mpsLU encodeToCommandBuffer: pendingCmdBuf_
1830- sourceMatrix: mpsA resultMatrix: mpsA
1831- pivotIndices: mpsPiv status: nil ];
1832-
1833- // When profiling, commit and wait to get per-getrf GPU timestamps
1834- if (metalProfilingEnabled ()) {
1835- [pendingCmdBuf_ commit ];
1836- [pendingCmdBuf_ waitUntilCompleted ];
1837- double gpuTimeMs = ([pendingCmdBuf_ GPUEndTime ] - [pendingCmdBuf_ GPUStartTime ]) * 1000.0 ;
1838- NSLog (@" [GPU] %-45s size=%lld x%lld gpu=%.3f ms" ,
1839- " MPS_LU_getrf" , m, n, gpuTimeMs);
1840- pendingCmdBuf_ = nil ;
1846+ if (!devAllPivots.buffer ()) {
1847+ // Non-general fallback: commit and read pivots back to CPU
1848+ commitPending ();
1849+ waitForGpu ();
1850+ int64_t * gpuPivots = devPivots.ptr ();
1851+ for (int64_t i = 0 ; i < minMN; i++) {
1852+ pivots[i] = gpuPivots[i];
18411853 }
1854+ }
18421855
1843- // GPU-resident pivot path: for general (LU) matrices with pre-allocated
1844- // devAllPivots, encode a GPU-side uint32→int64 conversion kernel and
1845- // keep pivots on GPU. For non-general matrices (e.g. simple LU tests),
1846- // fall back to CPU conversion with commitAndWait.
1847- if (devAllPivots.buffer ()) {
1848- // Compute offset into the persistent all-pivots buffer.
1849- if (!allPivotsCpuBase_) {
1850- allPivotsCpuBase_ = pivots; // First getrf call — record base
1851- }
1852- int64_t pivotOffset = pivots - allPivotsCpuBase_;
1853- allPivotsCount_ = std::max (allPivotsCount_, pivotOffset + minMN);
1854-
1855- // Encode GPU-side pivot conversion (uint32→int64) into the same cmd buffer.
1856- int64_t pivotByteOffset = pivotOffset * sizeof (int64_t );
1857- id <MTLComputePipelineState > convertPipeline = getProfiledPipeline (
1858- " lu_convertPivots_kernel_float" );
1859- encodeKernel (
1860- convertPipeline,
1861- ^(id <MTLComputeCommandEncoder > encoder) {
1862- [encoder setBuffer: (__bridge id <MTLBuffer >)devPivotBuf32.buffer ()
1863- offset: 0 atIndex: 0 ];
1864- [encoder setBuffer: (__bridge id <MTLBuffer >)devAllPivots.buffer ()
1865- offset: pivotByteOffset atIndex: 1 ];
1866- [encoder setBytes: &minMN length: sizeof (int64_t ) atIndex: 2 ];
1867- },
1868- (NSUInteger )minMN);
1856+ return 0 ;
1857+ }
18691858
1870- // Mark pivots as GPU-resident — applyRowPerm will skip memcpy.
1871- pivotsOnGpu_ = true ;
1872- } else {
1873- // Fallback: non-general matrix, commit and read pivots on CPU
1874- commitPending ();
1875- waitForGpu ();
1876- uint32_t * mpsPivots = devPivotBuf32.ptr ();
1877- for (int64_t i = 0 ; i < minMN; i++) {
1878- pivots[i] = static_cast <int64_t >(mpsPivots[i]);
1879- }
1859+ // MPS-based LU factorization — faster for normal mode but incompatible
1860+ // with external encoder (MPS requires encodeToCommandBuffer, not compute encoder).
1861+ int getrfMPS (int64_t m, int64_t n, int64_t minMN,
1862+ id <MTLBuffer > dataBuffer, size_t dataBaseOffset,
1863+ int64_t offA, int64_t * pivots) {
1864+ // End pending compute encoder (MPS needs its own encoding)
1865+ if (pendingEncoder_) {
1866+ [pendingEncoder_ endEncoding ];
1867+ pendingEncoder_ = nil ;
1868+ }
1869+
1870+ // Ensure we have a command buffer
1871+ if (!pendingCmdBuf_) {
1872+ pendingCmdBuf_ = [sym.commandQueue commandBuffer ];
1873+ }
1874+
1875+ // Create MPSMatrix view for the block at data+offA (row-major, m×n, stride=n)
1876+ MPSMatrixDescriptor* descA = [MPSMatrixDescriptor
1877+ matrixDescriptorWithRows: m columns: n
1878+ rowBytes: n * sizeof (float ) dataType: MPSDataTypeFloat32];
1879+ MPSMatrix* mpsA = [[MPSMatrix alloc ]
1880+ initWithBuffer: dataBuffer
1881+ offset: dataBaseOffset + offA * sizeof (float )
1882+ descriptor: descA];
1883+
1884+ // Pivot buffer (UInt32 format required by MPS)
1885+ devPivotBuf32.resizeToAtLeast (minMN);
1886+ MPSMatrixDescriptor* descPiv = [MPSMatrixDescriptor
1887+ matrixDescriptorWithRows: 1 columns: minMN
1888+ rowBytes: minMN * sizeof (uint32_t ) dataType: MPSDataTypeUInt32];
1889+ MPSMatrix* mpsPiv = [[MPSMatrix alloc ]
1890+ initWithBuffer: (__bridge id <MTLBuffer >)devPivotBuf32.buffer ()
1891+ offset: 0 descriptor: descPiv];
1892+
1893+ // Encode MPS LU factorization (in-place: resultMatrix = sourceMatrix)
1894+ MPSMatrixDecompositionLU* mpsLU = [[MPSMatrixDecompositionLU alloc ]
1895+ initWithDevice: sym.device rows: m columns: n];
1896+ [mpsLU encodeToCommandBuffer: pendingCmdBuf_
1897+ sourceMatrix: mpsA resultMatrix: mpsA
1898+ pivotIndices: mpsPiv status: nil ];
1899+
1900+ // When profiling, commit and wait to get per-getrf GPU timestamps
1901+ if (metalProfilingEnabled ()) {
1902+ [pendingCmdBuf_ commit ];
1903+ [pendingCmdBuf_ waitUntilCompleted ];
1904+ double gpuTimeMs = ([pendingCmdBuf_ GPUEndTime ] - [pendingCmdBuf_ GPUStartTime ]) * 1000.0 ;
1905+ NSLog (@" [GPU] %-45s size=%lld x%lld gpu=%.3f ms" ,
1906+ " MPS_LU_getrf" , m, n, gpuTimeMs);
1907+ pendingCmdBuf_ = nil ;
1908+ }
1909+
1910+ // GPU-resident pivot path: encode GPU-side uint32→int64 conversion
1911+ // and keep pivots on device. Non-general matrices fall back to CPU.
1912+ if (devAllPivots.buffer ()) {
1913+ if (!allPivotsCpuBase_) {
1914+ allPivotsCpuBase_ = pivots;
18801915 }
1916+ int64_t pivotOffset = pivots - allPivotsCpuBase_;
1917+ allPivotsCount_ = std::max (allPivotsCount_, pivotOffset + minMN);
1918+
1919+ int64_t pivotByteOffset = pivotOffset * sizeof (int64_t );
1920+ id <MTLComputePipelineState > convertPipeline = getProfiledPipeline (
1921+ " lu_convertPivots_kernel_float" );
1922+ encodeKernel (
1923+ convertPipeline,
1924+ ^(id <MTLComputeCommandEncoder > encoder) {
1925+ [encoder setBuffer: (__bridge id <MTLBuffer >)devPivotBuf32.buffer ()
1926+ offset: 0 atIndex: 0 ];
1927+ [encoder setBuffer: (__bridge id <MTLBuffer >)devAllPivots.buffer ()
1928+ offset: pivotByteOffset atIndex: 1 ];
1929+ [encoder setBytes: &minMN length: sizeof (int64_t ) atIndex: 2 ];
1930+ },
1931+ (NSUInteger )minMN);
18811932
1882- return 0 ;
1933+ pivotsOnGpu_ = true ;
1934+ } else {
1935+ // Fallback: non-general matrix, commit and read pivots on CPU
1936+ commitPending ();
1937+ waitForGpu ();
1938+ uint32_t * mpsPivots = devPivotBuf32.ptr ();
1939+ for (int64_t i = 0 ; i < minMN; i++) {
1940+ pivots[i] = static_cast <int64_t >(mpsPivots[i]);
1941+ }
18831942 }
1943+
1944+ return 0 ;
18841945 }
18851946
18861947 virtual void trsmLowerUnit (int64_t m, int64_t n, const float * L, int64_t offL, float * B,
0 commit comments