Commit 299df7e

[NVGPU] Fix nvdsl examples (#156830)
This PR fixes the nvdsl examples, which had drifted out of sync because they are not exercised in CI. The fixed bugs relate to the following PRs:
- move to nanobind (#118583)
- split gpu module initialization (#135478)
1 parent 0ade260 commit 299df7e
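The new RUN lines select between two modes: when %mlir_run_cuda_sm90_tests is "1" the example is executed and its runtime output is checked with the default CHECK prefix; otherwise MLIR_NVDSL_PRINT_IR=1 is exported and the printed IR is checked against the DUMPIR prefix. On the Python side, the examples guard their NumPy verification on the same variable, as in this minimal, self-contained sketch (the buffers here are placeholders; the real scripts compare the GPU results of their kernels, as the Ch1.py and Ch2.py hunks below show):

```python
# Minimal sketch of the verification guard the examples adopt, assuming only
# that MLIR_NVDSL_PRINT_IR=1 means "print IR instead of executing the kernel".
import os

import numpy as np

y = np.ones((4, 4), np.float32)    # placeholder for the kernel's output buffer
ref = np.ones((4, 4), np.float32)  # placeholder reference result

if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
    # Only compare against the reference when the kernel actually ran.
    np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
    print("PASS")
```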

File tree: 9 files changed (+603, −47 lines)

mlir/test/Examples/NVGPU/Ch0.py

Lines changed: 22 additions & 2 deletions
@@ -1,5 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 # Chapter 0 : Hello World
@@ -43,8 +47,24 @@ def kernel():
 # 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
 main(alpha)
 
-
 # CHECK: GPU thread 0 has 100
 # CHECK: GPU thread 1 has 101
 # CHECK: GPU thread 2 has 102
 # CHECK: GPU thread 3 has 103
+
+# DUMPIR: func.func @main(%arg0: index) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_0:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C4:.*]] = arith.constant 4 : index
+# DUMPIR: %[[C1_2:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C0_I32:.*]] = arith.constant 0 : i32
+# DUMPIR: gpu.launch blocks(%arg1, %arg2, %arg3) in (%arg7 = %[[C1]], %arg8 = %[[C1_0]], %arg9 = %[[C1_1]]) threads(%arg4, %arg5, %arg6) in (%arg10 = %[[C4]], %arg11 = %[[C1_2]], %arg12 = %[[C1_3]]) dynamic_shared_memory_size %[[C0_I32]] {
+# DUMPIR: %[[TIDX:.*]] = gpu.thread_id x
+# DUMPIR: %[[MYVAL:.*]] = arith.addi %arg0, %[[TIDX]] : index
+# DUMPIR: gpu.printf "GPU thread %llu has %llu\0A", %[[TIDX]], %[[MYVAL]] : index, index
+# DUMPIR: gpu.terminator
+# DUMPIR: }
+# DUMPIR: return
+# DUMPIR: }
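The DUMPIR lines above can be reproduced locally without an sm90 GPU by running the example with IR printing enabled, roughly as in the hypothetical helper below (the environment variable and script path come from this diff; depending on the build, SUPPORT_LIB may also need to point at the CUDA runtime wrappers, as the lit RUN line does):

```python
# Hypothetical local reproduction of the Ch0.py DUMPIR output. Assumes it is
# run from an llvm-project checkout with the MLIR Python bindings importable.
import os
import subprocess
import sys

env = dict(os.environ, MLIR_NVDSL_PRINT_IR="1")  # dump IR instead of executing
result = subprocess.run(
    [sys.executable, "mlir/test/Examples/NVGPU/Ch0.py"],
    env=env,
    capture_output=True,
    text=True,
    check=True,
)
print(result.stdout)  # the func.func @main / gpu.launch IR matched by the DUMPIR lines
```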

mlir/test/Examples/NVGPU/Ch1.py

Lines changed: 42 additions & 6 deletions
@@ -1,5 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 # Chapter 1 : 2D Saxpy
@@ -56,11 +60,43 @@ def saxpy_kernel():
 alpha = 2.0
 x = np.random.randn(M, N).astype(np.float32)
 y = np.ones((M, N), np.float32)
+
 saxpy(x, y, alpha)
 
-# 4. Verify MLIR with reference computation
-ref = np.ones((M, N), np.float32)
-ref += x * alpha
-np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+    # 4. Verify MLIR with reference computation
+    ref = np.ones((M, N), np.float32)
+    ref += x * alpha
+    np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @saxpy(%arg0: memref<256x32xf32>, %arg1: memref<256x32xf32>, %arg2: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
+# DUMPIR: %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
+# DUMPIR: %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %arg0 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %arg1 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
+# DUMPIR: %[[C256:.*]] = arith.constant 256 : index
+# DUMPIR: %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_2:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C32:.*]] = arith.constant 32 : index
+# DUMPIR: %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_4:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C0_I32:.*]] = arith.constant 0 : i32
+# DUMPIR: gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C256]], %arg10 = %[[C1]], %arg11 = %[[C1_2]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C32]], %arg13 = %[[C1_3]], %arg14 = %[[C1_4]]) dynamic_shared_memory_size %[[C0_I32]] {
+# DUMPIR: %[[BLOCKID:.*]] = gpu.block_id x
+# DUMPIR: %[[THREADID:.*]] = gpu.thread_id x
+# DUMPIR: %[[LD0:.*]] = memref.load %[[MEMREF]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR: %[[LD1:.*]] = memref.load %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR: %[[MUL:.*]] = arith.mulf %[[LD0]], %arg2 : f32
+# DUMPIR: %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
+# DUMPIR: memref.store %[[ADD]], %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR: gpu.terminator
+# DUMPIR: }
+# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg1, %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR: return
+# DUMPIR: }
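Reading the Ch1 DUMPIR block: the kernel launches a 256-block by 32-thread grid, and each (gpu.block_id x, gpu.thread_id x) pair loads one element of each 256x32 buffer, computes x * alpha + y, and stores the result back. The host-side NumPy sketch below spells out that index mapping (an illustration only, not the nvdsl API):

```python
# Host-side sketch of the computation in the Ch1 IR above: one block per row,
# one thread per column of the 256x32 buffers.
import numpy as np

M, N = 256, 32                     # grid: 256 blocks of 32 threads
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)

for bid in range(M):               # gpu.block_id x
    for tid in range(N):           # gpu.thread_id x
        # memref.load, arith.mulf, arith.addf, memref.store
        y[bid, tid] = y[bid, tid] + x[bid, tid] * alpha
```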

mlir/test/Examples/NVGPU/Ch2.py

Lines changed: 76 additions & 6 deletions
@@ -1,5 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 # Chapter 2 : 2D Saxpy with TMA
@@ -85,9 +89,75 @@ def saxpy_tma_kernel():
 y = np.ones((M, N), np.float32)
 saxpy(x, y, alpha)
 
-# 4. Verify MLIR with reference computation
-ref = np.ones((M, N), np.float32)
-ref += x * alpha
-np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+    # 4. Verify MLIR with reference computation
+    ref = np.ones((M, N), np.float32)
+    ref += x * alpha
+    np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @saxpy(%arg0: memref<256x32xf32>, %arg1: memref<256x32xf32>, %arg2: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
+# DUMPIR: %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
+# DUMPIR: %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %arg0 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %arg1 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
+# DUMPIR: %[[CAST:.*]] = memref.cast %[[MEMREF]] : memref<256x32xf32> to memref<*xf32>
+# DUMPIR: %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C32:.*]] = arith.constant 32 : index
+# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST]] box[%[[C1]], %[[C32]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[CAST2:.*]] = memref.cast %[[MEMREF0]] : memref<256x32xf32> to memref<*xf32>
+# DUMPIR: %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C32_4:.*]] = arith.constant 32 : index
+# DUMPIR: %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST2]] box[%[[C1_3]], %[[C32_4]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[C256:.*]] = arith.constant 256 : index
+# DUMPIR: %[[C1_5:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_6:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C32_7:.*]] = arith.constant 32 : index
+# DUMPIR: %[[C1_8:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_9:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C256_I32:.*]] = arith.constant 256 : i32
+# DUMPIR: gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C256]], %arg10 = %[[C1_5]], %arg11 = %[[C1_6]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C32_7]], %arg13 = %[[C1_8]], %arg14 = %[[C1_9]]) dynamic_shared_memory_size %[[C256_I32]] {
+# DUMPIR: %[[BLOCKID:.*]] = gpu.block_id x
+# DUMPIR: %[[THREADID:.*]] = gpu.thread_id x
+# DUMPIR: %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index
+# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_10:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C1_11:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_10]]], %[[C1_11]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_12]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR: %[[VIEW_13:.*]] = memref.view %[[DSM1]][%[[C128]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_14:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_15:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%[[C0_15]], %[[BLOCKID]]], %[[MB]][%[[C0_14]]] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_16:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_17:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C0_17]], %[[BLOCKID]]], %[[MB]][%[[C0_16]]] to %[[VIEW_13]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_18:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C256_19:.*]] = arith.constant 256 : index
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_18]]], %[[C256_19]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_20:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index
+# DUMPIR: %[[FALSE:.*]] = arith.constant false
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_20]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index
+# DUMPIR: %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %[[THREADID]]] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index
+# DUMPIR: %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %[[THREADID]]] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[MUL:.*]] = arith.mulf %[[LD0]], %arg2 : f32
+# DUMPIR: %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
+# DUMPIR: memref.store %[[ADD]], %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR: gpu.terminator
+# DUMPIR: }
+# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg1, %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR: return
+# DUMPIR: }
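A quick sanity check of the shared-memory accounting in the Ch2 IR above: each TMA destination is a memref<1x32xf32> view into dynamic shared memory, so the second view starts at byte offset 128 and the launch reserves 256 bytes in total, matching the %[[C128]] and %[[C256_I32]] constants.

```python
# Worked check of the dynamic shared memory layout shown in the Ch2 IR.
x_tile_bytes = 1 * 32 * 4       # memref<1x32xf32> view for the x tile -> 128 bytes
y_tile_offset = x_tile_bytes    # second memref.view starts at byte offset 128
total_bytes = 2 * x_tile_bytes  # two tiles packed back to back
assert y_tile_offset == 128 and total_bytes == 256  # dynamic_shared_memory_size
```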
