Skip to content

Commit 6c0aa6f

Browse files
committed
address review, remove dump IR from decorator
1 parent d1e4ae5 commit 6c0aa6f

File tree

8 files changed

+160
-167
lines changed

8 files changed

+160
-167
lines changed

mlir/test/Examples/NVGPU/Ch0.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2-
# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
3-
# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
2+
# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
43
# RUN: then %PYTHON %s | FileCheck %s; \
5-
# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
4+
# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
5+
# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
6+
67

78
# ===----------------------------------------------------------------------===//
89
# Chapter 0 : Hello World
@@ -21,12 +22,10 @@
2122
from tools.nvdsl import *
2223

2324

24-
dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
25-
2625
# 1. The decorator generates a MLIR func.func.
2726
# Everything inside the Python function becomes the body of the func.
2827
# The decorator also translates `alpha` to an `index` type.
29-
@NVDSL.mlir_func(dump_only)
28+
@NVDSL.mlir_func
3029
def main(alpha):
3130
# 2. The decorator generates a MLIR gpu.launch.
3231
# Everything inside the Python function becomes the body of the gpu.launch.

mlir/test/Examples/NVGPU/Ch1.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2-
# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
3-
# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
2+
# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
43
# RUN: then %PYTHON %s | FileCheck %s; \
5-
# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
4+
# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
5+
# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
6+
67

78
# ===----------------------------------------------------------------------===//
89
# Chapter 1 : 2D Saxpy
@@ -22,9 +23,9 @@
2223
from tools.nvdsl import *
2324
import numpy as np
2425

25-
dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
2626

27-
@NVDSL.mlir_func(dump_only)
27+
28+
@NVDSL.mlir_func
2829
def saxpy(x, y, alpha):
2930
# 1. Use MLIR GPU dialect to allocate and copy memory
3031
token_ty = gpu.AsyncTokenType.get()
@@ -63,7 +64,7 @@ def saxpy_kernel():
6364

6465
saxpy(x, y, alpha)
6566

66-
if not dump_only:
67+
if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
6768
# 4. Verify MLIR with reference computation
6869
ref = np.ones((M, N), np.float32)
6970
ref += x * alpha

mlir/test/Examples/NVGPU/Ch2.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2-
# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
3-
# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
2+
# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
43
# RUN: then %PYTHON %s | FileCheck %s; \
5-
# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
4+
# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
5+
# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
6+
67

78
# ===----------------------------------------------------------------------===//
89
# Chapter 2 : 2D Saxpy with TMA
@@ -27,9 +28,7 @@
2728
from mlir.extras import types as T
2829
import numpy as np
2930

30-
dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
31-
32-
@NVDSL.mlir_func(dump_only)
31+
@NVDSL.mlir_func
3332
def saxpy(x, y, alpha):
3433
token_ty = gpu.AsyncTokenType.get()
3534
t1 = gpu.wait(token_ty, [])
@@ -89,7 +88,7 @@ def saxpy_tma_kernel():
8988
y = np.ones((M, N), np.float32)
9089
saxpy(x, y, alpha)
9190

92-
if not dump_only:
91+
if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
9392
# 4. Verify MLIR with reference computation
9493
ref = np.ones((M, N), np.float32)
9594
ref += x * alpha

mlir/test/Examples/NVGPU/Ch3.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2-
# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
3-
# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
2+
# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
43
# RUN: then %PYTHON %s | FileCheck %s; \
5-
# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
4+
# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
5+
# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
6+
67

78
# ===----------------------------------------------------------------------===//
89
# Chapter 3 : GEMM 128x128x64 with Tensor Core
@@ -24,8 +25,6 @@
2425
from mlir.extras import types as T
2526
import numpy as np
2627

27-
dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
28-
2928
def tma_load(
3029
mbar_group: Mbarriers,
3130
a_tma: TMA,
@@ -61,7 +60,7 @@ def tma_load(
6160
b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
6261

6362

64-
@NVDSL.mlir_func(dump_only)
63+
@NVDSL.mlir_func
6564
def gemm_128_128_64(a, b, d):
6665
token_ty = gpu.AsyncTokenType.get()
6766
t1 = gpu.wait(token_ty, [])
@@ -127,7 +126,7 @@ def gemm_tma_kernel():
127126
d = np.zeros((M, N), np.float32)
128127
gemm_128_128_64(a, b, d)
129128

130-
if not dump_only:
129+
if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
131130
# Verify MLIR program with reference computation in python
132131
ref_d = a.astype(np.float16) @ b.astype(np.float16)
133132
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)

mlir/test/Examples/NVGPU/Ch4.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2-
# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
3-
# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
2+
# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
43
# RUN: then %PYTHON %s | FileCheck %s; \
5-
# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
4+
# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
5+
# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
66

77

88
# ===----------------------------------------------------------------------===//
@@ -51,7 +51,7 @@
5151
from tools.nvdsl import *
5252
import numpy as np
5353

54-
dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
54+
5555

5656
def partition_shape():
5757
"""
@@ -261,7 +261,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
261261
# a -> memref<MxKxf16>
262262
# b -> memref<NxKf16>
263263
# d -> memref<MxNxf32>
264-
@NVDSL.mlir_func(dump_only)
264+
@NVDSL.mlir_func
265265
def gemm_multistage(a, b, d, num_stages):
266266
token_ty = gpu.AsyncTokenType.get()
267267
t1 = gpu.wait(token_ty, [])
@@ -318,8 +318,7 @@ def gemm_multistage_kernel():
318318

319319
gemm_multistage(a, b, d, num_stages=7)
320320

321-
322-
if not dump_only:
321+
if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
323322
# Verify MLIR with reference computation
324323
ref_d = a.astype(np.float16) @ b.astype(np.float16)
325324
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)

mlir/test/Examples/NVGPU/Ch5.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2-
# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
3-
# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
2+
# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
43
# RUN: then %PYTHON %s | FileCheck %s; \
5-
# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
4+
# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
5+
# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
6+
67

78
# ===----------------------------------------------------------------------===//
89
# Chapter 5 : Warp Specialized GEMM with Tensor Core
@@ -50,7 +51,7 @@
5051
from tools.nvdsl import *
5152
import numpy as np
5253

53-
dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
54+
5455

5556
def partition_shape():
5657
"""
@@ -254,7 +255,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
254255
scf.yield_([])
255256

256257

257-
@NVDSL.mlir_func(dump_only)
258+
@NVDSL.mlir_func
258259
def gemm_warp_specialized(a, b, d, num_stages):
259260
token_ty = gpu.AsyncTokenType.get()
260261
t1 = gpu.wait(token_ty, [])
@@ -315,7 +316,7 @@ def gemm_warp_specialized_kernel():
315316

316317
gemm_warp_specialized(a, b, d, num_stages=7)
317318

318-
if not dump_only:
319+
if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
319320
# Verify MLIR with reference computation
320321
ref_d = a.astype(np.float16) @ b.astype(np.float16)
321322
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
config.unsupported = False
2-
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
2+
if not config.enable_cuda_runner:
33
config.unsupported = True
44

0 commit comments

Comments (0)