3535
3636!rhs = memref <64 x128 xf16 >
3737!shmemrhs = memref <64 x128 xf16 , 3 >
38- !rhsTensorMap = !nvgpu.tensormap.descriptor <tensor = !shmemrhs , swizzle = swizzle_128b , l2promo =none , oob =zero , interleave =none >
38+ !rhsTensorMap = !nvgpu.tensormap.descriptor <tensor = memref < 64 x 64 x f16 , 3 > , swizzle = swizzle_128b , l2promo =none , oob =zero , interleave =none >
3939
4040module @mymod {
4141 func.func private @printMemrefF32 (memref <*xf32 >)
@@ -99,7 +99,8 @@ module @mymod {
9999 %6 = gpu.thread_id x
100100 %lhsShmem = memref.get_global @bufferLhsGlobal : !shmemlhs
101101 %rhsShmem = memref.get_global @bufferRhsGlobal : !shmemrhs
102- %rhsShmem2 = memref.subview %rhsShmem [32 , 0 ][128 , 64 ][1 , 1 ] : !shmemrhs to memref <128 x64 xf16 , strided <[128 , 1 ], offset : 4096 >, 3 >
102+ %rhsShmem1 = memref.subview %rhsShmem [0 , 0 ][64 , 64 ][1 , 1 ] : !shmemrhs to memref <64 x64 xf16 , strided <[128 , 1 ]>, 3 >
103+ %rhsShmem2 = memref.subview %rhsShmem [32 , 0 ][64 , 64 ][1 , 1 ] : !shmemrhs to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 4096 >, 3 >
103104
104105 // Step 5. Initialize the mbarrier
105106 %9 = nvgpu.mbarrier.create -> !barrierType
@@ -110,8 +111,8 @@ module @mymod {
110111 scf.if %10 {
111112 gpu.printf " [GPU] TMA SIZE %d\0A" %c32768 : index
112113 nvgpu.tma.async.load %d_lhsTensorMap [%c0 , %c0 ], %9 [%c0 ] to %lhsShmem : !lhsTensorMap , !barrierType -> !shmemlhs
113- nvgpu.tma.async.load %d_rhsTensorMap [%c0 , %c0 ], %9 [%c0 ] to %rhsShmem : !rhsTensorMap , !barrierType -> !shmemrhs
114- nvgpu.tma.async.load %d_rhsTensorMap [%c64 , %c0 ], %9 [%c0 ] to %rhsShmem2 : !rhsTensorMap , !barrierType -> memref <128 x 64 x f16 , strided <[128 , 1 ], offset : 4096 >, 3 >
114+ nvgpu.tma.async.load %d_rhsTensorMap [%c0 , %c0 ], %9 [%c0 ] to %rhsShmem1 : !rhsTensorMap , !barrierType -> memref < 64 x 64 x f16 , strided <[ 128 , 1 ]>, 3 >
115+ nvgpu.tma.async.load %d_rhsTensorMap [%c64 , %c0 ], %9 [%c0 ] to %rhsShmem2 : !rhsTensorMap , !barrierType -> memref <64 x 64 x f16 , strided <[128 , 1 ], offset : 4096 >, 3 >
115116 nvgpu.mbarrier.arrive.expect_tx %9 [%c0 ], %c32768 : !barrierType
116117 } else {
117118 nvgpu.mbarrier.arrive.expect_tx %9 [%c0 ], %c0 : !barrierType
0 commit comments