@@ -1,5 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 # Chapter 3 : GEMM 128x128x64 with Tensor Core
@@ -60,13 +64,13 @@ def tma_load(
 @NVDSL.mlir_func
 def gemm_128_128_64(a, b, d):
     token_ty = gpu.AsyncTokenType.get()
-    t1 = gpu.wait(token_ty, [])
+    t1 = gpu.wait([])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
     d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
     t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
     t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
-    t7 = gpu.wait(token_ty, [t6])
+    t7 = gpu.wait([t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
     a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -111,7 +115,7 @@ def gemm_tma_kernel():
     gemm_tma_kernel()
 
     t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
-    gpu.wait(None, [t8])
+    gpu.wait([t8])
 
 
 # Python pass arguments to MLIR
@@ -123,7 +127,73 @@ def gemm_tma_kernel():
 d = np.zeros((M, N), np.float32)
 gemm_128_128_64(a, b, d)
 
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+    # Verify MLIR program with reference computation in python
+    ref_d = a.astype(np.float16) @ b.astype(np.float16)
+    np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @gemm_128_128_64(%{{.*}}: memref<128x64xf16>, %{{.*}}: memref<64x128xf16>, %[[ARG2:.*]]: memref<128x128xf32>) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR: %[[C64:.*]] = arith.constant 64 : index
+# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %{{.*}} box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[CAST1:.*]] = memref.cast %{{.*}} : memref<64x128xf16> to memref<*xf16>
+# DUMPIR: %[[C64_5:.*]] = arith.constant 64 : index
+# DUMPIR: %[[C64_6:.*]] = arith.constant 64 : index
+# DUMPIR: %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST1]] box[%[[C64_5]], %[[C64_6]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[THREADID:.*]] = gpu.thread_id x
+# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index
+# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C1_13:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_12]]], %[[C1_13]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA0]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA1]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_14:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_14]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[VIEW_15:.*]] = memref.view %[[DSM1]][%[[C16384]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_16:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW_17:.*]] = memref.view %[[DSM2]][%[[C0_16]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C16384_18:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[VIEW_19:.*]] = memref.view %[[DSM3]][%[[C16384_18]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM4:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C24576:.*]] = arith.constant 24576 : index
+# DUMPIR: %[[VIEW_20:.*]] = memref.view %[[DSM4]][%[[C24576]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_21]]], %[[C32768]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_23:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_24:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%[[C0_23]], %[[C0_24]]], %[[MB]][%[[C0_22]]] to %[[VIEW_17]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_25:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_26:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_27:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C0_26]], %[[C0_27]]], %[[MB]][%[[C0_25]]] to %[[VIEW_19]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_28:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C64_29:.*]] = arith.constant 64 : index
+# DUMPIR: %[[C0_30:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C64_29]], %[[C0_30]]], %[[MB]][%[[C0_28]]] to %[[VIEW_20]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_31:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index
+# DUMPIR: %[[FALSE:.*]] = arith.constant false
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_31]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[WG_ACC:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR: %[[GEN0:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW]], %[[TMA0]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[GEN1:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_15]], %[[TMA1]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[MMA:.*]] = nvgpu.warpgroup.mma %[[GEN0]], %[[GEN1]], %[[WG_ACC]] {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR: nvgpu.warpgroup.mma.store %[[MMA]], %{{.*}} : <fragmented = vector<128x128xf32>> to memref<128x128xf32>
+# DUMPIR: gpu.terminator
+# DUMPIR: }
+# DUMPIR: %[[CPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG2]], %{{.*}} : memref<128x128xf32>, memref<128x128xf32>
+# DUMPIR: gpu.wait async [%[[CPY3]]]
+# DUMPIR: return
+# DUMPIR: }