11
11
from mlir .extras .dialects .ext import arith , memref , scf , gpu , linalg , transform
12
12
from mlir .dialects .transform import any_op_t
13
13
from mlir .extras .dialects .ext .func import func
14
- from mlir .extras .dialects .ext .nvgpu import tensormap_descriptor
14
+ from mlir .dialects .nvgpu import (
15
+ TensorMapDescriptorType ,
16
+ TensorMapSwizzleKind ,
17
+ TensorMapL2PromoKind ,
18
+ TensorMapOOBKind ,
19
+ TensorMapInterleaveKind ,
20
+ )
15
21
from mlir .dialects .transform .structured import MatchInterfaceEnum
16
22
from mlir .dialects .memref import cast
17
23
from mlir .dialects .nvgpu import tma_create_descriptor
@@ -37,7 +43,13 @@ def create_tensor_map(
37
43
crd0 = arith .constant (64 , index = True )
38
44
crd1 = arith .constant (128 , index = True )
39
45
device_ptr_2d_unranked = cast (T .memref (element_type = T .f32 ()), device_ptr_2d )
40
- tensor_map_2d = tensormap_descriptor (T .memref (32 , 32 , T .f32 (), memory_space = 3 ))
46
+ tensor_map_2d = TensorMapDescriptorType .get (
47
+ T .memref (32 , 32 , T .f32 (), memory_space = 3 ),
48
+ TensorMapSwizzleKind .SWIZZLE_NONE ,
49
+ TensorMapL2PromoKind .L2PROMO_NONE ,
50
+ TensorMapOOBKind .OOB_NAN ,
51
+ TensorMapInterleaveKind .INTERLEAVE_NONE ,
52
+ )
41
53
tensor_map_2d = tma_create_descriptor (
42
54
tensor_map_2d , device_ptr_2d_unranked , [crd0 , crd1 ]
43
55
)
@@ -187,8 +199,7 @@ def payload():
187
199
compute_linspace_val .emit ()
188
200
189
201
@func
190
- def printMemrefF32 (x : T .memref (T .f32 ())):
191
- ...
202
+ def printMemrefF32 (x : T .memref (T .f32 ())): ...
192
203
193
204
printMemrefF32_ .append (printMemrefF32 )
194
205
@@ -408,6 +419,7 @@ def main(module: any_op_t()):
408
419
409
420
CUDA_RUNTIME_LIB_PATH = Path (_mlir_libs .__file__ ).parent / f"libmlir_cuda_runtime.so"
410
421
422
+
411
423
# based on https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6
412
424
@pytest .mark .skipif (not CUDA_RUNTIME_LIB_PATH .exists (), reason = "no cuda library" )
413
425
def test_transform_mma_sync_matmul_f16_f16_accum_run (ctx : MLIRContext , capfd ):
@@ -536,8 +548,7 @@ def payload():
536
548
compute_linspace_val .emit ()
537
549
538
550
@func
539
- def printMemrefF32 (x : T .memref (T .f32 ())):
540
- ...
551
+ def printMemrefF32 (x : T .memref (T .f32 ())): ...
541
552
542
553
printMemrefF32_ .append (printMemrefF32 )
543
554
0 commit comments