Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,16 +383,14 @@ struct LowerGpuOpsToNVVMOpsPass final
LLVMConversionTarget target(getContext());

// Set higher benefit, so patterns will run before generic LLVM lowering.
// Make sure the benefit here is higher than ArithToLLVMDialectInterface and
// MathToLLVMDialectInterface.
populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
/*benefit=*/10);

llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
allowedDialects.end());
for (Dialect *dialect : getContext().getLoadedDialects()) {
// Skip math patterns as nvvm needs custom math lowering.
if (isa<math::MathDialect>(dialect))
continue;

bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
// Empty `allowedDialectsSet` means all dialects are allowed.
if (!allowedDialectsSet.empty() && !allowed)
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-generic-llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s

/// Math/arith ops that are not supported by libdevice
/// should be converted by generic LLVM lowering patterns.

gpu.module @generic_llvm_test_module_0 {
// CHECK-LABEL: @arith_add
func.func @arith_add(%left: i64, %right: i64) -> i64 {
// CHECK: llvm.add {{.*}}, {{.*}} : i64
%result = arith.addi %left, %right : i64
return %result : i64
}
}

gpu.module @generic_llvm_test_module_1 {
// CHECK-LABEL: @math_abs_non_i32
func.func @math_abs_non_i32(%arg_i64: i64, %arg_i16: i16, %arg_i8: i8, %arg_i1: i1)
-> (i64, i16, i8, i1) {
// CHECK: "llvm.intr.abs"{{.*}} : (i64) -> i64
%abs_i64 = math.absi %arg_i64 : i64
// CHECK: "llvm.intr.abs"{{.*}} : (i16) -> i16
%abs_i16 = math.absi %arg_i16 : i16
// CHECK: "llvm.intr.abs"{{.*}} : (i8) -> i8
%abs_i8 = math.absi %arg_i8 : i8
// CHECK: "llvm.intr.abs"{{.*}} : (i1) -> i1
%abs_i1 = math.absi %arg_i1 : i1
return %abs_i64, %abs_i16, %abs_i8, %abs_i1 : i64, i16, i8, i1
}
}