Skip to content

Commit 79e28d8

Browse files
author
Yang Bai
committed
Enable fallback to generic LLVM lowering for the math dialect in the convert-gpu-to-nvvm pass
1 parent 538c850 commit 79e28d8

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,14 @@ struct LowerGpuOpsToNVVMOpsPass final
383383
LLVMConversionTarget target(getContext());
384384

385385
// Set higher benefit, so patterns will run before generic LLVM lowering.
386+
// Make sure the benefit here is higher than ArithToLLVMDialectInterface and
387+
// MathToLLVMDialectInterface.
386388
populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
387389
/*benefit=*/10);
388390

389391
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
390392
allowedDialects.end());
391393
for (Dialect *dialect : getContext().getLoadedDialects()) {
392-
// Skip math patterns as nvvm needs custom math lowering.
393-
if (isa<math::MathDialect>(dialect))
394-
continue;
395-
396394
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
397395
// Empty `allowedDialectsSet` means all dialects are allowed.
398396
if (!allowedDialectsSet.empty() && !allowed)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
2+
3+
/// Math/arith ops that are not supported by libdevice
4+
/// should be converted by generic LLVM lowering patterns.
5+
6+
/// Verifies that a 64-bit integer `arith.addi` inside a GPU module is
/// rewritten to the LLVM dialect `llvm.add` op by -convert-gpu-to-nvvm.
gpu.module @generic_llvm_test_module_0 {
7+
// CHECK-LABEL: @arith_add
8+
func.func @arith_add(%left: i64, %right: i64) -> i64 {
9+
// CHECK: llvm.add {{.*}}, {{.*}} : i64
10+
%result = arith.addi %left, %right : i64
11+
return %result : i64
12+
}
13+
}
14+
15+
/// Verifies that `math.absi` on i64, i16, i8 and i1 operands is rewritten to
/// the `llvm.intr.abs` intrinsic of the matching width.
/// NOTE(review): the i32 case is deliberately absent; presumably it is handled
/// by a dedicated NVVM pattern tested elsewhere -- confirm.
gpu.module @generic_llvm_test_module_1 {
16+
// CHECK-LABEL: @math_abs_non_i32
17+
func.func @math_abs_non_i32(%arg_i64: i64, %arg_i16: i16, %arg_i8: i8, %arg_i1: i1)
18+
-> (i64, i16, i8, i1) {
19+
// CHECK: "llvm.intr.abs"{{.*}} : (i64) -> i64
20+
%abs_i64 = math.absi %arg_i64 : i64
21+
// CHECK: "llvm.intr.abs"{{.*}} : (i16) -> i16
22+
%abs_i16 = math.absi %arg_i16 : i16
23+
// CHECK: "llvm.intr.abs"{{.*}} : (i8) -> i8
24+
%abs_i8 = math.absi %arg_i8 : i8
25+
// CHECK: "llvm.intr.abs"{{.*}} : (i1) -> i1
26+
%abs_i1 = math.absi %arg_i1 : i1
27+
// All four results are returned so that none of the ops under test is unused.
return %abs_i64, %abs_i16, %abs_i8, %abs_i1 : i64, i16, i8, i1
28+
}
29+
}

0 commit comments

Comments
 (0)