
Commit aa08c7b

remove tensormap_descriptor (#83)
1 parent d7668eb commit aa08c7b

File tree

2 files changed: +17 -32 lines

mlir/extras/dialects/ext/nvgpu.py

Lines changed: 0 additions & 26 deletions
This file was deleted.
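The deleted module provided the tensormap_descriptor helper that the test below stops importing. Its contents are not shown in this diff; as a hypothetical reconstruction (an assumption, not the actual deleted code), it was presumably a thin wrapper over the upstream nvgpu bindings with the same defaults the test now passes explicitly:

# Hypothetical reconstruction -- NOT the actual deleted file, which this diff
# does not show. Assumes the helper simply forwarded to the upstream
# TensorMapDescriptorType.get with the defaults the test now spells out.
from mlir.dialects.nvgpu import (
    TensorMapDescriptorType,
    TensorMapSwizzleKind,
    TensorMapL2PromoKind,
    TensorMapOOBKind,
    TensorMapInterleaveKind,
)


def tensormap_descriptor(
    tensor_memref_type,
    swizzle=TensorMapSwizzleKind.SWIZZLE_NONE,
    l2promo=TensorMapL2PromoKind.L2PROMO_NONE,
    oob=TensorMapOOBKind.OOB_NAN,
    interleave=TensorMapInterleaveKind.INTERLEAVE_NONE,
):
    # Build the !nvgpu.tensormap.descriptor type from a shared-memory memref
    # type plus the TMA descriptor attributes.
    return TensorMapDescriptorType.get(
        tensor_memref_type, swizzle, l2promo, oob, interleave
    )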

tests/test_nvgpu_nvvm.py

Lines changed: 17 additions & 6 deletions
@@ -11,7 +11,13 @@
 from mlir.extras.dialects.ext import arith, memref, scf, gpu, linalg, transform
 from mlir.dialects.transform import any_op_t
 from mlir.extras.dialects.ext.func import func
-from mlir.extras.dialects.ext.nvgpu import tensormap_descriptor
+from mlir.dialects.nvgpu import (
+    TensorMapDescriptorType,
+    TensorMapSwizzleKind,
+    TensorMapL2PromoKind,
+    TensorMapOOBKind,
+    TensorMapInterleaveKind,
+)
 from mlir.dialects.transform.structured import MatchInterfaceEnum
 from mlir.dialects.memref import cast
 from mlir.dialects.nvgpu import tma_create_descriptor
@@ -37,7 +43,13 @@ def create_tensor_map(
     crd0 = arith.constant(64, index=True)
     crd1 = arith.constant(128, index=True)
     device_ptr_2d_unranked = cast(T.memref(element_type=T.f32()), device_ptr_2d)
-    tensor_map_2d = tensormap_descriptor(T.memref(32, 32, T.f32(), memory_space=3))
+    tensor_map_2d = TensorMapDescriptorType.get(
+        T.memref(32, 32, T.f32(), memory_space=3),
+        TensorMapSwizzleKind.SWIZZLE_NONE,
+        TensorMapL2PromoKind.L2PROMO_NONE,
+        TensorMapOOBKind.OOB_NAN,
+        TensorMapInterleaveKind.INTERLEAVE_NONE,
+    )
     tensor_map_2d = tma_create_descriptor(
         tensor_map_2d, device_ptr_2d_unranked, [crd0, crd1]
     )
@@ -187,8 +199,7 @@ def payload():
     compute_linspace_val.emit()

     @func
-    def printMemrefF32(x: T.memref(T.f32())):
-        ...
+    def printMemrefF32(x: T.memref(T.f32())): ...

     printMemrefF32_.append(printMemrefF32)

@@ -408,6 +419,7 @@ def main(module: any_op_t()):

 CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / f"libmlir_cuda_runtime.so"

+
 # based on https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6
 @pytest.mark.skipif(not CUDA_RUNTIME_LIB_PATH.exists(), reason="no cuda library")
 def test_transform_mma_sync_matmul_f16_f16_accum_run(ctx: MLIRContext, capfd):
@@ -536,8 +548,7 @@ def payload():
     compute_linspace_val.emit()

     @func
-    def printMemrefF32(x: T.memref(T.f32())):
-        ...
+    def printMemrefF32(x: T.memref(T.f32())): ...

     printMemrefF32_.append(printMemrefF32)
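Taken together, the change in this test drops the local wrapper and constructs the descriptor type directly from the upstream mlir.dialects.nvgpu bindings. Below is a minimal consolidated sketch of the new call pattern; it assumes the test's surrounding fixtures (the T types alias, an existing device_ptr_2d value, and an active MLIR context and insertion point), so it is illustrative rather than standalone runnable.

# Consolidated sketch of the post-change usage, assuming the test's T alias,
# an existing device_ptr_2d memref value, and an active MLIR context and
# insertion point as set up by tests/test_nvgpu_nvvm.py.
from mlir.extras.dialects.ext import arith
from mlir.dialects.memref import cast
from mlir.dialects.nvgpu import (
    TensorMapDescriptorType,
    TensorMapSwizzleKind,
    TensorMapL2PromoKind,
    TensorMapOOBKind,
    TensorMapInterleaveKind,
    tma_create_descriptor,
)

crd0 = arith.constant(64, index=True)
crd1 = arith.constant(128, index=True)
device_ptr_2d_unranked = cast(T.memref(element_type=T.f32()), device_ptr_2d)

# Descriptor type for a 32x32 f32 tile in shared memory (memory_space=3):
# no swizzle, no L2 promotion, NaN fill for out-of-bounds reads, no interleave.
tensor_map_2d = TensorMapDescriptorType.get(
    T.memref(32, 32, T.f32(), memory_space=3),
    TensorMapSwizzleKind.SWIZZLE_NONE,
    TensorMapL2PromoKind.L2PROMO_NONE,
    TensorMapOOBKind.OOB_NAN,
    TensorMapInterleaveKind.INTERLEAVE_NONE,
)

# The descriptor value is then created from the unranked device pointer and
# the tile coordinates.
tensor_map_2d = tma_create_descriptor(
    tensor_map_2d, device_ptr_2d_unranked, [crd0, crd1]
)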
