From 7c849d7695247c9222cb8dd73b66aa35c328e650 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 17 Jan 2025 17:29:34 -0600 Subject: [PATCH 1/3] [MLIR][Math][GPU] Add lowering of absi and fpowi to libdevice --- .../GPUCommon/OpToFuncCallLowering.h | 188 +++++++++++++----- .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 19 ++ .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 24 +++ 3 files changed, 181 insertions(+), 50 deletions(-) diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 46fd182346b3b..bbfcdaf91205c 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -16,37 +16,11 @@ namespace mlir { -/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or -/// `f32ApproxFunc` or `f16Func` depending on the element type and the -/// fastMathFlag of that Op. The function declaration is added in case it was -/// not added before. -/// -/// If the input values are of bf16 type (or f16 type if f16Func is empty), the -/// value is first casted to f32, the function called and then the result casted -/// back. -/// -/// Example with NVVM: -/// %exp_f32 = math.exp %arg_f32 : f32 -/// -/// will be transformed into -/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 -/// -/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers -/// to the approximate calculation function. 
-/// -/// Also example with NVVM: -/// %exp_f32 = math.exp %arg_f32 fastmath : f32 -/// -/// will be transformed into -/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 -template -struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { +template +struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { public: - explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering, - StringRef f32Func, StringRef f64Func, - StringRef f32ApproxFunc, StringRef f16Func) - : ConvertOpToLLVMPattern(lowering), f32Func(f32Func), - f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {} + explicit OpToFuncCallLoweringBase(const LLVMTypeConverter &lowering) + : ConvertOpToLLVMPattern(lowering) {} LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, @@ -72,13 +46,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { SmallVector castedOperands; for (Value operand : adaptor.getOperands()) - castedOperands.push_back(maybeCast(operand, rewriter)); + castedOperands.push_back( + ((const DerivedTy *)this)->maybeCast(operand, rewriter)); Type resultType = castedOperands.front().getType(); Type funcType = getFunctionType(resultType, castedOperands); StringRef funcName = - getFunctionName(cast(funcType).getReturnType(), - op.getFastmath()); + ((const DerivedTy *)this) + ->getFunctionName( + cast(funcType).getReturnType(), op); if (funcName.empty()) return failure(); @@ -99,6 +75,61 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { } private: + Type getFunctionType(Type resultType, ValueRange operands) const { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); + } + + LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, + Operation *op) const { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return 
cast(*funcOp); + + mlir::OpBuilder b(op->getParentOfType()); + return b.create(op->getLoc(), funcName, funcType); + } +}; + +/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or +/// `f32ApproxFunc` or `f16Func` depending on the element type and the +/// fastMathFlag of that Op. The function declaration is added in case it was +/// not added before. +/// +/// If the input values are of bf16 type (or f16 type if f16Func is empty), the +/// value is first casted to f32, the function called and then the result casted +/// back. +/// +/// Example with NVVM: +/// %exp_f32 = math.exp %arg_f32 : f32 +/// +/// will be transformed into +/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 +/// +/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers +/// to the approximate calculation function. +/// +/// Also example with NVVM: +/// %exp_f32 = math.exp %arg_f32 fastmath : f32 +/// +/// will be transformed into +/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 +template +struct OpToFuncCallLowering + : public OpToFuncCallLoweringBase> { +public: + explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering, + StringRef f32Func, StringRef f64Func, + StringRef f32ApproxFunc, StringRef f16Func) + : OpToFuncCallLoweringBase>( + lowering), + f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), + f16Func(f16Func) {} + Value maybeCast(Value operand, PatternRewriter &rewriter) const { Type type = operand.getType(); if (!isa(type)) @@ -112,12 +143,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); } - Type getFunctionType(Type resultType, ValueRange operands) const { - SmallVector operandTypes(operands.getTypes()); - return LLVM::LLVMFunctionType::get(resultType, operandTypes); - } - - StringRef getFunctionName(Type type, arith::FastMathFlags flag) const { + StringRef getFunctionName(Type type, SourceOp op) const { + 
arith::FastMathFlags flag = op.getFastmath(); if (isa(type)) return f16Func; if (isa(type)) { @@ -132,23 +159,84 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { return ""; } - LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, - Operation *op) const { - using LLVM::LLVMFuncOp; + const std::string f32Func; + const std::string f64Func; + const std::string f32ApproxFunc; + const std::string f16Func; +}; - auto funcAttr = StringAttr::get(op->getContext(), funcName); - Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); - if (funcOp) - return cast(*funcOp); +/// Rewriting that replace SourceOp with a CallOp to `i32Func` +/// The function declaration is added in case it was not added before. +/// This assumes that all types integral. +/// +/// Example with NVVM: +/// %abs_i32 = math.iabs %arg_i32 : i32 +/// +/// will be transformed into +/// llvm.call @__nv_abs(%arg_i32) : (i32) -> i32 +/// +template +struct IntOpToFuncCallLowering + : public OpToFuncCallLoweringBase> { +public: + explicit IntOpToFuncCallLowering(const LLVMTypeConverter &lowering, + StringRef i32Func) + : OpToFuncCallLoweringBase>( + lowering), + i32Func(i32Func) {} - mlir::OpBuilder b(op->getParentOfType()); - return b.create(op->getLoc(), funcName, funcType); + Value maybeCast(Value operand, PatternRewriter &rewriter) const { + return operand; + } + + StringRef getFunctionName(Type type, SourceOp op) const { + IntegerType itype = dyn_cast(type); + if (!itype || itype.getWidth() != 32) + return ""; + return i32Func; + } + + const std::string i32Func; +}; + +/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`, +/// depending on the type of the result. This assumes that the first argument is +/// a floating type and the second argument is an integer type. 
+/// +/// Example with NVVM: +/// %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32 +/// +/// will be transformed into +/// llvm.call @__nv_powf(%arg_f32, %arg_i32) : (f32, i32) -> f32 +/// +template +struct FloatIntOpToFuncCallLowering + : public OpToFuncCallLoweringBase> { +public: + explicit FloatIntOpToFuncCallLowering(const LLVMTypeConverter &lowering, + StringRef f32Func, StringRef f64Func) + : OpToFuncCallLoweringBase>( + lowering), + f32Func(f32Func), f64Func(f64Func) {} + + Value maybeCast(Value operand, PatternRewriter &rewriter) const { + return operand; + } + + StringRef getFunctionName(Type type, SourceOp op) const { + if (isa(type)) { + return f32Func; + } + if (isa(type)) + return f64Func; + return ""; } const std::string f32Func; const std::string f64Func; - const std::string f32ApproxFunc; - const std::string f16Func; }; } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 2768929f460e2..1971de30898fb 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -446,6 +446,22 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +template +static void populateIntOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + StringRef i32Func) { + patterns.add>(converter); + patterns.add>(converter, i32Func); +} + +template +static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + StringRef f32Func, StringRef f64Func) { + patterns.add>(converter); + patterns.add>(converter, f32Func, f64Func); +} + void mlir::populateGpuSubgroupReduceOpLoweringPattern( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(converter); @@ -509,6 +525,7 @@ void mlir::populateGpuToNVVMConversionPatterns( populateOpPatterns(converter, patterns, "__nv_fmodf", 
"__nv_fmod"); + populateIntOpPatterns(converter, patterns, "__nv_abs"); populateOpPatterns(converter, patterns, "__nv_fabsf", "__nv_fabs"); populateOpPatterns(converter, patterns, "__nv_acosf", @@ -555,6 +572,8 @@ void mlir::populateGpuToNVVMConversionPatterns( "__nv_log2", "__nv_fast_log2f"); populateOpPatterns(converter, patterns, "__nv_powf", "__nv_pow", "__nv_fast_powf"); + populateFloatIntOpPatterns(converter, patterns, "__nv_powif", + "__nv_powi"); populateOpPatterns(converter, patterns, "__nv_roundf", "__nv_round"); populateOpPatterns(converter, patterns, "__nv_rintf", diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index f52dd6c0d0ce3..94c0f9e34c29c 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1033,3 +1033,27 @@ module attributes {transform.with_named_sequence} { transform.yield } } + + +gpu.module @test_module_52 { + // CHECK: llvm.func @__nv_abs(i32) -> i32 + // CHECK-LABEL: func @gpu_abs + func.func @gpu_abs(%arg_i32 : i32) -> (i32) { + %result32 = math.absi %arg_i32 : i32 + // CHECK: llvm.call @__nv_abs(%{{.*}}) : (i32) -> i32 + func.return %result32 : i32 + } +} + +gpu.module @test_module_53 { + // CHECK: llvm.func @__nv_powif(f32, i32) -> f32 + // CHECK: llvm.func @__nv_powi(f64, i32) -> f64 + // CHECK-LABEL: func @gpu_powi + func.func @gpu_powi(%arg_f32 : f32, %arg_f64 : f64, %arg_i32 : i32) -> (f32, f64) { + %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32 + // CHECK: llvm.call @__nv_powif(%{{.*}}, %{{.*}}) : (f32, i32) -> f32 + %result64 = math.fpowi %arg_f64, %arg_i32 : f64, i32 + // CHECK: llvm.call @__nv_powi(%{{.*}}, %{{.*}}) : (f64, i32) -> f64 + func.return %result32, %result64 : f32, f64 + } +} From d6b89f00447add33943da404fa198d80103c8c47 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 19 Jan 2025 14:19:06 -0600 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Oleksandr 
"Alex" Zinenko --- mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index bbfcdaf91205c..0c1755d593339 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -47,12 +47,12 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { SmallVector castedOperands; for (Value operand : adaptor.getOperands()) castedOperands.push_back( - ((const DerivedTy *)this)->maybeCast(operand, rewriter)); + static_cast(this)->maybeCast(operand, rewriter)); Type resultType = castedOperands.front().getType(); Type funcType = getFunctionType(resultType, castedOperands); StringRef funcName = - ((const DerivedTy *)this) + static_cast(this) ->getFunctionName( cast(funcType).getReturnType(), op); if (funcName.empty()) @@ -85,11 +85,13 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { using LLVM::LLVMFuncOp; auto funcAttr = StringAttr::get(op->getContext(), funcName); - Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + auto funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); if (funcOp) - return cast(*funcOp); + return funcOp; - mlir::OpBuilder b(op->getParentOfType()); + auto parentFunc = op->getParentOfType(); + assert(parentFunc && "expected there to be a parent function"); + OpBuilder b(parentFunc); return b.create(op->getLoc(), funcName, funcType); } }; From cd3148548e2d92ab94edd08e9381131174b4b392 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 20 Jan 2025 17:36:04 +0100 Subject: [PATCH 3/3] replace CRTP with a simple trait in a common flow --- .../GPUCommon/OpToFuncCallLowering.h | 216 ++++++------------ .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 4 +- 2 files changed, 78 insertions(+), 142 deletions(-) diff --git 
a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 0c1755d593339..9f7ceb11752ba 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -16,11 +16,51 @@ namespace mlir { -template -struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { +namespace { +/// Detection trait for the `getFastmath` instance method. +template +using has_get_fastmath_t = decltype(std::declval().getFastmath()); +} // namespace + +/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or +/// `f32ApproxFunc` or `f16Func` or `i32Func` depending on the element type and +/// the fastMathFlag of that Op, if present. The function declaration is added +/// in case it was not added before. +/// +/// If the input values are of bf16 type (or f16 type if f16Func is empty), the +/// value is first casted to f32, the function called and then the result casted +/// back. +/// +/// Example with NVVM: +/// %exp_f32 = math.exp %arg_f32 : f32 +/// +/// will be transformed into +/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 +/// +/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers +/// to the approximate calculation function. 
+/// +/// Also example with NVVM: +/// %exp_f32 = math.exp %arg_f32 fastmath : f32 +/// +/// will be transformed into +/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 +/// +/// Final example with NVVM: +/// %pow_f32 = math.fpowi %arg_f32, %arg_i32 : f32, i32 +/// +/// will be transformed into +/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32 +template +struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { public: - explicit OpToFuncCallLoweringBase(const LLVMTypeConverter &lowering) - : ConvertOpToLLVMPattern(lowering) {} + explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering, + StringRef f32Func, StringRef f64Func, + StringRef f32ApproxFunc, StringRef f16Func, + StringRef i32Func = "") + : ConvertOpToLLVMPattern(lowering), f32Func(f32Func), + f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func), + i32Func(i32Func) {} LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, @@ -46,15 +86,12 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { SmallVector castedOperands; for (Value operand : adaptor.getOperands()) - castedOperands.push_back( - static_cast(this)->maybeCast(operand, rewriter)); + castedOperands.push_back(maybeCast(operand, rewriter)); Type resultType = castedOperands.front().getType(); Type funcType = getFunctionType(resultType, castedOperands); - StringRef funcName = - static_cast(this) - ->getFunctionName( - cast(funcType).getReturnType(), op); + StringRef funcName = getFunctionName( + cast(funcType).getReturnType(), op); if (funcName.empty()) return failure(); @@ -67,6 +104,8 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { return success(); } + assert(callOp.getResult().getType().isF32() && + "only f32 types are supposed to be truncated back"); Value truncated = rewriter.create( op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); @@ -74,7 +113,19 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { return 
success(); } -private: + Value maybeCast(Value operand, PatternRewriter &rewriter) const { + Type type = operand.getType(); + if (!isa(type)) + return operand; + + // if there's a f16 function, no need to cast f16 values + if (!f16Func.empty() && isa(type)) + return operand; + + return rewriter.create( + operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + } + Type getFunctionType(Type resultType, ValueRange operands) const { SmallVector operandTypes(operands.getTypes()); return LLVM::LLVMFunctionType::get(resultType, operandTypes); @@ -85,7 +136,8 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { using LLVM::LLVMFuncOp; auto funcAttr = StringAttr::get(op->getContext(), funcName); - auto funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + auto funcOp = + SymbolTable::lookupNearestSymbolFrom(op, funcAttr); if (funcOp) return funcOp; @@ -94,70 +146,27 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern { OpBuilder b(parentFunc); return b.create(op->getLoc(), funcName, funcType); } -}; - -/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or -/// `f32ApproxFunc` or `f16Func` depending on the element type and the -/// fastMathFlag of that Op. The function declaration is added in case it was -/// not added before. -/// -/// If the input values are of bf16 type (or f16 type if f16Func is empty), the -/// value is first casted to f32, the function called and then the result casted -/// back. -/// -/// Example with NVVM: -/// %exp_f32 = math.exp %arg_f32 : f32 -/// -/// will be transformed into -/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 -/// -/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers -/// to the approximate calculation function. 
-/// -/// Also example with NVVM: -/// %exp_f32 = math.exp %arg_f32 fastmath : f32 -/// -/// will be transformed into -/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 -template -struct OpToFuncCallLowering - : public OpToFuncCallLoweringBase> { -public: - explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering, - StringRef f32Func, StringRef f64Func, - StringRef f32ApproxFunc, StringRef f16Func) - : OpToFuncCallLoweringBase>( - lowering), - f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), - f16Func(f16Func) {} - - Value maybeCast(Value operand, PatternRewriter &rewriter) const { - Type type = operand.getType(); - if (!isa(type)) - return operand; - - // if there's a f16 function, no need to cast f16 values - if (!f16Func.empty() && isa(type)) - return operand; - - return rewriter.create( - operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); - } StringRef getFunctionName(Type type, SourceOp op) const { - arith::FastMathFlags flag = op.getFastmath(); + bool useApprox = false; + if constexpr (llvm::is_detected::value) { + arith::FastMathFlags flag = op.getFastmath(); + useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) && + !f32ApproxFunc.empty(); + } + if (isa(type)) return f16Func; if (isa(type)) { - if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) && - !f32ApproxFunc.empty()) + if (useApprox) return f32ApproxFunc; - else - return f32Func; + return f32Func; } if (isa(type)) return f64Func; + + if (type.isInteger(32)) + return i32Func; return ""; } @@ -165,82 +174,9 @@ struct OpToFuncCallLowering const std::string f64Func; const std::string f32ApproxFunc; const std::string f16Func; -}; - -/// Rewriting that replace SourceOp with a CallOp to `i32Func` -/// The function declaration is added in case it was not added before. -/// This assumes that all types integral. 
-/// -/// Example with NVVM: -/// %abs_i32 = math.iabs %arg_i32 : i32 -/// -/// will be transformed into -/// llvm.call @__nv_abs(%arg_i32) : (i32) -> i32 -/// -template -struct IntOpToFuncCallLowering - : public OpToFuncCallLoweringBase> { -public: - explicit IntOpToFuncCallLowering(const LLVMTypeConverter &lowering, - StringRef i32Func) - : OpToFuncCallLoweringBase>( - lowering), - i32Func(i32Func) {} - - Value maybeCast(Value operand, PatternRewriter &rewriter) const { - return operand; - } - - StringRef getFunctionName(Type type, SourceOp op) const { - IntegerType itype = dyn_cast(type); - if (!itype || itype.getWidth() != 32) - return ""; - return i32Func; - } - const std::string i32Func; }; -/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`, -/// depending on the type of the result. This assumes that the first argument is -/// a floating type and the second argument is an integer type. -/// -/// Example with NVVM: -/// %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32 -/// -/// will be transformed into -/// llvm.call @__nv_powf(%arg_f32, %arg_i32) : (f32, i32) -> f32 -/// -template -struct FloatIntOpToFuncCallLowering - : public OpToFuncCallLoweringBase> { -public: - explicit FloatIntOpToFuncCallLowering(const LLVMTypeConverter &lowering, - StringRef f32Func, StringRef f64Func) - : OpToFuncCallLoweringBase>( - lowering), - f32Func(f32Func), f64Func(f64Func) {} - - Value maybeCast(Value operand, PatternRewriter &rewriter) const { - return operand; - } - - StringRef getFunctionName(Type type, SourceOp op) const { - if (isa(type)) { - return f32Func; - } - if (isa(type)) - return f64Func; - return ""; - } - - const std::string f32Func; - const std::string f64Func; -}; - } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 1971de30898fb..11363a0d60ebf 100644 --- 
a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -451,7 +451,7 @@ static void populateIntOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef i32Func) { patterns.add>(converter); - patterns.add>(converter, i32Func); + patterns.add>(converter, "", "", "", "", i32Func); } template @@ -459,7 +459,7 @@ static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func) { patterns.add>(converter); - patterns.add>(converter, f32Func, f64Func); + patterns.add>(converter, f32Func, f64Func, "", ""); } void mlir::populateGpuSubgroupReduceOpLoweringPattern(