Skip to content

Commit 9f114af

Browse files
authored
[MLIR][ROCDL] Convert math::fpowi to ROCDL call (#122640)
* Have to relax static assert to allow reuse of existing template patterns for conversion.
1 parent d90a427 commit 9f114af

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
5757
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
5858
"expected single result op");
5959

60-
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
61-
SourceOp>::value,
62-
"expected op with same operand and result types");
60+
if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
61+
SourceOp>::value) {
62+
assert(op->getNumOperands() > 0 &&
63+
"expected op to take at least one operand");
64+
assert(op->getResultTypes().front() == op->getOperand(0).getType() &&
65+
"expected op with same operand and result types");
66+
}
6367

6468
if (!op->template getParentOfType<FunctionOpInterface>()) {
6569
return rewriter.notifyMatchFailure(

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ void mlir::populateMathToROCDLConversionPatterns(
5757
// Handled by mathToLLVM: math::FmaOp
5858
// Handled by mathToLLVM: math::LogOp (32-bit only)
5959
// FIXME: math::IPowIOp
60-
// FIXME: math::FPowIOp
6160
// Handled by mathToLLVM: math::RoundEvenOp
6261
// Handled by mathToLLVM: math::RoundOp
6362
// Handled by mathToLLVM: math::SqrtOp
@@ -114,6 +113,8 @@ void mlir::populateMathToROCDLConversionPatterns(
114113
"__ocml_tan_f64", "__ocml_tan_f16");
115114
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
116115
"__ocml_erf_f64", "__ocml_erf_f16");
116+
populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
117+
"__ocml_pown_f64", "__ocml_pown_f16");
117118
// Single arith pattern that needs a ROCDL call, probably not
118119
// worth creating a separate pass for it.
119120
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,24 @@ module @test_module {
484484

485485
// -----
486486

487+
module @test_module {
488+
// CHECK: llvm.func @__ocml_pown_f16(f16, i32) -> f16
489+
// CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
490+
// CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
491+
// CHECK-LABEL: func @math_fpowi
492+
func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
493+
// CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
494+
%0 = math.fpowi %arg0, %arg3 : f16, i32
495+
// CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
496+
%1 = math.fpowi %arg1, %arg3 : f32, i32
497+
// CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
498+
%2 = math.fpowi %arg2, %arg3 : f64, i32
499+
return %0, %1, %2 : f16, f32, f64
500+
}
501+
}
502+
503+
// -----
504+
487505
// Math operation not inside function
488506
// Ensure it not crash
489507

0 commit comments

Comments
 (0)