5757
5858func.func private @printMemrefF32 (memref <*xf32 >)
5959
60- memref.global " private" @dynamicShmem : memref <0 xf16 , 3 > {alignment = 16 : i64 }
6160memref.global " private" @accShmem : memref <0 xf32 , 3 > {alignment = 16 : i64 }
6261
6362func.func @main () {
@@ -148,12 +147,11 @@ func.func @main() {
148147 %c57344 = arith.constant 57344 : index
149148 %c40960 = arith.constant 40960 : index
150149
151- %tidx = gpu.thread_id x
152- %dynamicMem = memref.get_global @dynamicShmem : memref <0 xf16 , 3 >
153- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset : [0 ], sizes : [2 , 128 , 64 ], strides : [8192 , 64 , 1 ] : memref <0 xf16 , 3 > to memref <2 x128 x64 xf16 , 3 >
154- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset : [0 ], sizes : [4 , 64 , 128 ], strides : [8192 ,128 ,1 ] : memref <0 xf16 , 3 > to memref <4 x64 x128 xf16 ,3 >
155- %rhsShmem = memref.subview %rhsShmem2 [2 , 0 , 0 ][2 , 64 , 128 ][1 , 1 , 1 ] : memref <4 x64 x128 xf16 ,3 > to memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 >
150+ %tidx = gpu.thread_id x
156151 %dynsmem = gpu.dynamic_shared_memory : memref <?xi8 , #gpu.address_space <workgroup >>
152+ %lhsShmem = memref.view %dynsmem [%c0 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <2 x128 x64 xf16 , #gpu.address_space <workgroup >>
153+ %rhsShmem = memref.view %dynsmem [%c32768 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <2 x64 x128 xf16 , #gpu.address_space <workgroup >>
154+
157155 // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
158156 %barrier = nvgpu.mbarrier.create -> !barrierType
159157 %cnd = arith.cmpi eq , %tidx , %c0 : index
@@ -202,11 +200,11 @@ func.func @main() {
202200 // TMA wait
203201 %phase_c0 = arith.constant 0 : i1
204202 nvgpu.mbarrier.try_wait.parity %barrier [%i ], %phase_c0 , %ticks : !barrierType
205- %lhsSlice = memref.subview %lhsShmem [%i , 0 , 0 ][1 , 128 , 64 ][1 , 1 , 1 ] : memref <2 x128 x64 xf16 , 3 > to memref <128 x64 xf16 , strided <[64 , 1 ], offset : ?>, 3 >
206- %rhsSlice = memref.subview %rhsShmem [%i , 0 , 0 ][1 , 64 , 128 ][1 , 1 , 1 ] : memref <2 x64 x128 xf16 , strided <[ 8192 , 128 , 1 ], offset : 16384 >, 3 > to memref <64 x128 xf16 , strided <[128 , 1 ], offset : ?>, 3 >
203+ %lhsSlice = memref.subview %lhsShmem [%i , 0 , 0 ][1 , 128 , 64 ][1 , 1 , 1 ] : memref <2 x128 x64 xf16 , #gpu.address_space < workgroup >> to memref <128 x64 xf16 , strided <[64 , 1 ], offset : ?>, #gpu.address_space < workgroup > >
204+ %rhsSlice = memref.subview %rhsShmem [%i , 0 , 0 ][1 , 64 , 128 ][1 , 1 , 1 ] : memref <2 x64 x128 xf16 , #gpu.address_space < workgroup > > to memref <64 x128 xf16 , strided <[128 , 1 ], offset : ?>, #gpu.address_space < workgroup > >
207205 // Descriptor WGMMA
208- %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice , %descA : memref <128 x64 xf16 , strided <[64 , 1 ], offset : ?>, 3 >, !lhsTensorMap -> !nvgpu.warpgroup.descriptor <tensor =memref <128 x64 xf16 , 3 >>
209- %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice , %descB : memref <64 x128 xf16 , strided <[128 , 1 ], offset : ?>, 3 >, !rhsTensorMap -> !nvgpu.warpgroup.descriptor <tensor =memref <64 x128 xf16 , 3 >>
206+ %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice , %descA : memref <128 x64 xf16 , strided <[64 , 1 ], offset : ?>, #gpu.address_space < workgroup > >, !lhsTensorMap -> !nvgpu.warpgroup.descriptor <tensor =memref <128 x64 xf16 , 3 >>
207+ %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice , %descB : memref <64 x128 xf16 , strided <[128 , 1 ], offset : ?>, #gpu.address_space < workgroup > >, !rhsTensorMap -> !nvgpu.warpgroup.descriptor <tensor =memref <64 x128 xf16 , 3 >>
210208 // Perform WGMMA 128x128x64
211209 %md = nvgpu.warpgroup.mma %dA , %dB , %mc {transposeB } : <tensor = memref <128 x64 xf16 ,3 >>, <tensor = memref <64 x128 xf16 ,3 >>, <fragmented = vector <128 x128 xf32 >> -> <fragmented = vector <128 x128 xf32 >>
212210 scf.yield %md : !nvgpu.warpgroup.accumulator <fragmented = vector <128 x128 xf32 >>
@@ -271,7 +269,7 @@ func.func @main() {
271269 vector.print str " Correct Results :"
272270 vector.print %correctCount : i32
273271 vector.print str " Incorrect Results :"
274- vector.print %errorCount : i32
272+ vector.print %errorCount : i32
275273
276274 return
277- }
275+ }
0 commit comments