@@ -142,13 +142,17 @@ func.func @main() {
   %c4096 = arith.constant 4096 : index
   %c8 = arith.constant 8 : index
   %txcount = arith.constant 32768 : index
+  %c24576 = arith.constant 24576 : index
+  %c16384 = arith.constant 16384 : index
+  %c49152 = arith.constant 49152 : index
+  %c57344 = arith.constant 57344 : index

   %tidx = gpu.thread_id x
   %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
   %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
   %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<4x64x128xf16, 3>
   %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16, 3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
-
+  %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
   // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
   %barrier = nvgpu.mbarrier.create -> !barrierType

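(For reference: the new index constants are byte offsets into the flat dynamic shared memory buffer. A 128x64xf16 lhs tile takes 128 * 64 * 2 = 16384 bytes and each 64x64xf16 rhs half takes 8192 bytes, so pipeline 1 occupies offsets 0, 16384, and 24576, and pipeline 2 follows at 32768, 49152, and 57344. %c0, %c64, and %c32768, used below, are presumably defined elsewhere in the file, outside these hunks.)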
@@ -175,28 +179,25 @@ func.func @main() {

   // Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
   %pipe1 = arith.constant 0 : index
-  %p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
-  %p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
-  %p1halfFirst = memref.subview %p1rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-  %p1halfSecond = memref.subview %p1rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+  %lhsSlice1 = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+  %halfFirst1 = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+  %halfSecond1 = memref.view %dynsmem[%c24576][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
   nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType
   %dim1 = arith.muli %pipe1, %c64 : index
-  nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
-  nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-  nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+  nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %lhsSlice1, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %halfFirst1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %halfSecond1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>

   // Step 5. [GPU] TMA Load Pipeline 2 (predicated)
   %pipe2 = arith.constant 1 : index
-  %p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-  %p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
-  %p2halfFirst = memref.subview %p2rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-  %p2halfSecond = memref.subview %p2rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
+  %lhsSlice2 = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+  %halfFirst2 = memref.view %dynsmem[%c49152][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+  %halfSecond2 = memref.view %dynsmem[%c57344][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
   nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
   %dim2 = arith.muli %pipe2, %c64 : index
-  nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-  nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-  nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
-
+  nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %lhsSlice2, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %halfFirst2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %halfSecond2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
   // Step 6. [GPU] Initialize accumulator matrix
   %14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>

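For context, here is a minimal standalone sketch of the pattern this patch adopts: take the block's dynamic shared memory as one flat i8 buffer and carve typed tiles out of it with memref.view at byte offsets. Module, function, and value names here are illustrative, not taken from the patch.

gpu.module @example_kernels {
  gpu.func @views_into_dynamic_smem() kernel {
    %c0 = arith.constant 0 : index
    %c16384 = arith.constant 16384 : index
    // One flat view of the kernel's dynamic shared memory.
    %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
    // A 128x64xf16 tile at byte offset 0 (128 * 64 * 2 = 16384 bytes).
    %lhs = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
    // A 64x64xf16 tile placed immediately after it.
    %rhs = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
    gpu.return
  }
}

Compared with the old reinterpret_cast/subview chain over a zero-sized global, each tile's placement becomes a plain byte offset, so the strided<..., offset: ...> layouts disappear from the TMA load result types.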