Skip to content

Commit 24fbbcb

Browse files
committed
remove hard dependency on modules for mathtoxevm
1 parent 4f23767 commit 24fbbcb

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
800800
// MathToXeVM
801801
//===----------------------------------------------------------------------===//
802802

803-
def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
803+
def ConvertMathToXeVM : Pass<"convert-math-to-xevm"> {
804804
let summary =
805805
"Convert (fast) math operations to native XeVM/SPIRV equivalents";
806806
let description = [{
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt --pass-pipeline="builtin.module(convert-math-to-xevm)" %s \
2+
// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ENTIRE-MODULE'
3+
// RUN: mlir-opt --pass-pipeline="builtin.module(gpu.module(convert-math-to-xevm))" %s \
4+
// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ONLY-GPU'
5+
//
6+
// Check that MathToXeVM handles nested modules while respecting pass manager.
7+
8+
// CHECK-LABEL: @test_module
9+
module @test_module {
10+
// CHECK-ENTIRE-MODULE: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
11+
// CHECK-ONLY-GPU-NOT: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
12+
13+
// CHECK-LABEL: @test_gpu
14+
gpu.module @test_gpu {
15+
// CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
16+
gpu.func @exp_gpu() {
17+
%c1_f32 = arith.constant 1. : f32
18+
19+
// CHECK: math.exp
20+
%exp_normal_f32 = math.exp %c1_f32 : f32
21+
22+
// CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
23+
%exp_fast_f32 = math.exp %c1_f32 fastmath<afn> : f32
24+
25+
gpu.return
26+
}
27+
}
28+
29+
// CHECK-LABEL: @exp_func
30+
func.func @exp_func() {
31+
%c1_f32 = arith.constant 1. : f32
32+
33+
// CHECK: math.exp
34+
%exp_normal_f32 = math.exp %c1_f32 : f32
35+
36+
// CHECK-ENTIRE-MODULE: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
37+
// CHECK-ONLY-GPU: math.exp
38+
%exp_fast_f32 = math.exp %c1_f32 fastmath<afn> : f32
39+
40+
return
41+
}
42+
}

0 commit comments

Comments
 (0)