@@ -142,13 +142,18 @@ func.func @main() {
   %c4096 = arith.constant 4096 : index
   %c8 = arith.constant 8 : index
   %txcount = arith.constant 32768 : index
+  %c24576 = arith.constant 24576 : index
+  %c16384 = arith.constant 16384 : index
+  %c49152 = arith.constant 49152 : index
+  %c57344 = arith.constant 57344 : index
+  %c40960 = arith.constant 40960 : index
 
   %tidx = gpu.thread_id x
   %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
   %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
   %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<4x64x128xf16, 3>
   %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16, 3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
-
+  %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
   // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
   %barrier = nvgpu.mbarrier.create -> !barrierType
 
@@ -175,28 +180,25 @@ func.func @main() {
 
   // Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
   %pipe1 = arith.constant 0 : index
-  %p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
-  %p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
-  %p1halfFirst = memref.subview %p1rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-  %p1halfSecond = memref.subview %p1rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+  %lhsSlice1 = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+  %halfFirst1 = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+  %halfSecond1 = memref.view %dynsmem[%c40960][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
   nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType
   %dim1 = arith.muli %pipe1, %c64 : index
-  nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
-  nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-  nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+  nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %lhsSlice1, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %halfFirst1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %halfSecond1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
 
   // Step 5. [GPU] TMA Load Pipeline 2 (predicated)
   %pipe2 = arith.constant 1 : index
-  %p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-  %p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
-  %p2halfFirst = memref.subview %p2rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-  %p2halfSecond = memref.subview %p2rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
+  %lhsSlice2 = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+  %halfFirst2 = memref.view %dynsmem[%c49152][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+  %halfSecond2 = memref.view %dynsmem[%c57344][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
   nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
   %dim2 = arith.muli %pipe2, %c64 : index
-  nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-  nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-  nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
-
+  nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %lhsSlice2, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %halfFirst2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+  nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %halfSecond2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
   // Step 6. [GPU] Initialize accumulator matrix
   %14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
 
@@ -282,4 +284,3 @@ func.func @main() {
   return
 }
 
-