|
1 | 1 | import re
|
| 2 | +import subprocess |
2 | 3 | from pathlib import Path
|
3 | 4 | from textwrap import dedent
|
4 | 5 |
|
|
8 | 9 | from mlir.dialects.memref import cast
|
9 | 10 | from mlir.dialects.nvgpu import (
|
10 | 11 | TensorMapDescriptorType,
|
11 |
| - TensorMapSwizzleKind, |
| 12 | + TensorMapInterleaveKind, |
12 | 13 | TensorMapL2PromoKind,
|
13 | 14 | TensorMapOOBKind,
|
14 |
| - TensorMapInterleaveKind, |
| 15 | + TensorMapSwizzleKind, |
| 16 | + tma_create_descriptor, |
15 | 17 | )
|
16 |
| -from mlir.dialects.nvgpu import tma_create_descriptor |
17 | 18 | from mlir.dialects.transform import any_op_t
|
18 | 19 | from mlir.dialects.transform.extras import named_sequence
|
19 | 20 | from mlir.dialects.transform.structured import MatchInterfaceEnum
|
20 | 21 | from mlir.ir import StringAttr, UnitAttr
|
21 | 22 |
|
22 | 23 | from mlir import _mlir_libs
|
23 | 24 | from mlir.extras.ast.canonicalize import canonicalize
|
24 |
| -from mlir.extras.dialects.ext import arith, memref, scf, gpu, linalg, transform, nvgpu |
| 25 | +from mlir.extras.dialects.ext import arith, gpu, linalg, memref, nvgpu, scf, transform |
25 | 26 | from mlir.extras.dialects.ext.func import func
|
26 | 27 | from mlir.extras.dialects.ext.gpu import smem_space
|
27 | 28 | from mlir.extras.dialects.ext.llvm import llvm_ptr_t
|
28 |
| -from mlir.extras.runtime.passes import run_pipeline, Pipeline |
| 29 | +from mlir.extras.runtime.passes import Pipeline, run_pipeline |
29 | 30 | from mlir.extras.runtime.refbackend import LLVMJITBackend
|
30 | 31 |
|
31 | 32 | # noinspection PyUnresolvedReferences
|
32 |
| -from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext |
| 33 | +from mlir.extras.testing import MLIRContext, filecheck, mlir_ctx as ctx |
33 | 34 | from mlir.extras.util import find_ops
|
34 | 35 |
|
35 | 36 | # needed since the fixture isn't defined here nor in conftest.py
|
@@ -200,7 +201,8 @@ def payload():
|
200 | 201 | compute_linspace_val.emit()
|
201 | 202 |
|
202 | 203 | @func
|
203 |
| - def printMemrefF32(x: T.memref(T.f32())): ... |
| 204 | + def printMemrefF32(x: T.memref(T.f32())): |
| 205 | + ... |
204 | 206 |
|
205 | 207 | printMemrefF32_.append(printMemrefF32)
|
206 | 208 |
|
@@ -421,8 +423,15 @@ def main(module: any_op_t()):
|
421 | 423 | CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / f"libmlir_cuda_runtime.so"
|
422 | 424 |
|
423 | 425 |
|
| 426 | +NVIDIA_GPU = False |
| 427 | +try: |
| 428 | + subprocess.check_output("nvidia-smi") |
| 429 | + NVIDIA_GPU = True |
| 430 | +except Exception: |
| 431 | + print("No Nvidia GPU in system!") |
| 432 | + |
424 | 433 | # based on https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6
|
425 |
| -@pytest.mark.skipif(not CUDA_RUNTIME_LIB_PATH.exists(), reason="no cuda library") |
| 434 | +@pytest.mark.skipif(not NVIDIA_GPU, reason="no cuda library") |
426 | 435 | def test_transform_mma_sync_matmul_f16_f16_accum_run(ctx: MLIRContext, capfd):
|
427 | 436 | range_ = scf.range_
|
428 | 437 |
|
@@ -549,7 +558,8 @@ def payload():
|
549 | 558 | compute_linspace_val.emit()
|
550 | 559 |
|
551 | 560 | @func
|
552 |
| - def printMemrefF32(x: T.memref(T.f32())): ... |
| 561 | + def printMemrefF32(x: T.memref(T.f32())): |
| 562 | + ... |
553 | 563 |
|
554 | 564 | printMemrefF32_.append(printMemrefF32)
|
555 | 565 |
|
|
0 commit comments