diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 9f76f5d1d8c92..70e3e45c225db 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -807,7 +807,7 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> { // MathToXeVM //===----------------------------------------------------------------------===// -def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> { +def ConvertMathToXeVM : Pass<"convert-math-to-xevm"> { let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents"; let description = [{ diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir index d76627bb4201c..c61640c2afc4f 100644 --- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir +++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir @@ -3,6 +3,15 @@ // RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \ // RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH' +// RUN: mlir-opt --pass-pipeline="builtin.module(convert-math-to-xevm)" %s \ +// RUN: | FileCheck %s -check-prefixes='CHECK-MODULE,CHECK-ENTIRE-MODULE' +// RUN: mlir-opt --pass-pipeline="builtin.module(gpu.module(convert-math-to-xevm))" %s \ +// RUN: | FileCheck %s -check-prefixes='CHECK-MODULE,CHECK-ONLY-GPU' + +// This test: +// - check that MathToXeVM converts fastmath math/arith ops properly; +// - check that MathToXeVM handles nested modules while respecting pass manager. + module @test_module { // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 @@ -152,4 +161,39 @@ module @test_module { return } + + // Check that MathToXeVM handles nested modules while respecting pass manager: + + // CHECK-ENTIRE-MODULE: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 + // CHECK-ONLY-GPU-NOT: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 + + // CHECK-MODULE-LABEL: @test_gpu + gpu.module @test_gpu { + // CHECK-MODULE: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 + gpu.func @exp_gpu() { + %c1_f32 = arith.constant 1. : f32 + + // CHECK-MODULE: math.exp + %exp_normal_f32 = math.exp %c1_f32 : f32 + + // CHECK-MODULE: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %exp_fast_f32 = math.exp %c1_f32 fastmath : f32 + + gpu.return + } + } + + // CHECK-MODULE-LABEL: @exp_func + func.func @exp_func() { + %c1_f32 = arith.constant 1. : f32 + + // CHECK-MODULE: math.exp + %exp_normal_f32 = math.exp %c1_f32 : f32 + + // CHECK-ENTIRE-MODULE: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + // CHECK-ONLY-GPU: math.exp + %exp_fast_f32 = math.exp %c1_f32 fastmath : f32 + + return + } }