diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 46fd182346b3b..0c1755d593339 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -16,37 +16,11 @@ namespace mlir {
-/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` or `f16Func` depending on the element type and the
-/// fastMathFlag of that Op. The function declaration is added in case it was
-/// not added before.
-///
-/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
-/// value is first casted to f32, the function called and then the result casted
-/// back.
-///
-/// Example with NVVM:
-/// %exp_f32 = math.exp %arg_f32 : f32
-///
-/// will be transformed into
-/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
-///
-/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
-/// to the approximate calculation function.
-///
-/// Also example with NVVM:
-/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
-///
-/// will be transformed into
-/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
-template <typename SourceOp>
-struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
+template <typename SourceOp, typename Derived>
+struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
 public:
-  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
-                                StringRef f32Func, StringRef f64Func,
-                                StringRef f32ApproxFunc, StringRef f16Func)
-      : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
-        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
+  explicit OpToFuncCallLoweringBase(const LLVMTypeConverter &lowering)
+      : ConvertOpToLLVMPattern<SourceOp>(lowering) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -72,13 +46,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 
     SmallVector<Value, 1> castedOperands;
     for (Value operand : adaptor.getOperands())
-      castedOperands.push_back(maybeCast(operand, rewriter));
+      castedOperands.push_back(
+          static_cast<const Derived *>(this)->maybeCast(operand, rewriter));
 
     Type resultType = castedOperands.front().getType();
     Type funcType = getFunctionType(resultType, castedOperands);
     StringRef funcName =
-        getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
-                        op.getFastmath());
+        static_cast<const Derived *>(this)
+            ->getFunctionName(
+                cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
     if (funcName.empty())
       return failure();
 
@@ -99,6 +75,63 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   }
 
 private:
+  Type getFunctionType(Type resultType, ValueRange operands) const {
+    SmallVector<Type> operandTypes(operands.getTypes());
+    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
+  }
+
+  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
+                                     Operation *op) const {
+    using LLVM::LLVMFuncOp;
+
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+    if (funcOp)
+      return funcOp;
+
+    auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+    assert(parentFunc && "expected there to be a parent function");
+    OpBuilder b(parentFunc);
+    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
+  }
+};
+
+/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the
+/// fastMathFlag of that Op. The function declaration is added in case it was
+/// not added before.
+///
+/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
+/// value is first cast to f32, the function is called, and then the result is
+/// cast back.
+///
+/// Example with NVVM:
+/// %exp_f32 = math.exp %arg_f32 : f32
+///
+/// will be transformed into
+/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
+///
+/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
+/// to the approximate calculation function.
+///
+/// Also example with NVVM:
+/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
+///
+/// will be transformed into
+/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
+template <typename SourceOp>
+struct OpToFuncCallLowering
+    : public OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>> {
+public:
+  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
+                                StringRef f32Func, StringRef f64Func,
+                                StringRef f32ApproxFunc, StringRef f16Func)
+      : OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>>(
+            lowering),
+        f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc),
+        f16Func(f16Func) {}
+
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
     Type type = operand.getType();
     if (!isa<Float16Type, BFloat16Type>(type))
@@ -112,12 +145,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
   }
 
-  Type getFunctionType(Type resultType, ValueRange operands) const {
-    SmallVector<Type> operandTypes(operands.getTypes());
-    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
-  }
-
-  StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+  StringRef getFunctionName(Type type, SourceOp op) const {
+    arith::FastMathFlags flag = op.getFastmath();
     if (isa<Float16Type>(type))
       return f16Func;
     if (isa<Float32Type>(type)) {
@@ -132,23 +161,84 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     return "";
   }
 
-  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
-                                     Operation *op) const {
-    using LLVM::LLVMFuncOp;
+  const std::string f32Func;
+  const std::string f64Func;
+  const std::string f32ApproxFunc;
+  const std::string f16Func;
+};
 
-    auto funcAttr = StringAttr::get(op->getContext(), funcName);
-    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
-    if (funcOp)
-      return cast<LLVMFuncOp>(*funcOp);
+/// Rewriting that replaces SourceOp with a CallOp to `i32Func`.
+/// The function declaration is added in case it was not added before.
+/// This assumes that all types are integral.
+///
+/// Example with NVVM:
+/// %abs_i32 = math.absi %arg_i32 : i32
+///
+/// will be transformed into
+/// llvm.call @__nv_abs(%arg_i32) : (i32) -> i32
+///
+template <typename SourceOp>
+struct IntOpToFuncCallLowering
+    : public OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>> {
+public:
+  explicit IntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
+                                   StringRef i32Func)
+      : OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>>(
+            lowering),
+        i32Func(i32Func) {}
 
-    mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
-    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    return operand;
+  }
+
+  StringRef getFunctionName(Type type, SourceOp op) const {
+    IntegerType itype = dyn_cast<IntegerType>(type);
+    if (!itype || itype.getWidth() != 32)
+      return "";
+    return i32Func;
+  }
+
+  const std::string i32Func;
+};
+
+/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
+/// depending on the type of the result. This assumes that the first argument
+/// is a floating-point type and the second argument is an integer type.
+///
+/// Example with NVVM:
+/// %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
+///
+/// will be transformed into
+/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
+///
+template <typename SourceOp>
+struct FloatIntOpToFuncCallLowering
+    : public OpToFuncCallLoweringBase<SourceOp, FloatIntOpToFuncCallLowering<SourceOp>> {
+public:
+  explicit FloatIntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
+                                        StringRef f32Func, StringRef f64Func)
+      : OpToFuncCallLoweringBase<SourceOp, FloatIntOpToFuncCallLowering<SourceOp>>(
+            lowering),
+        f32Func(f32Func), f64Func(f64Func) {}
+
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    return operand;
+  }
+
+  StringRef getFunctionName(Type type, SourceOp op) const {
+    if (isa<Float32Type>(type)) {
+      return f32Func;
+    }
+    if (isa<Float64Type>(type))
+      return f64Func;
+    return "";
   }
 
   const std::string f32Func;
   const std::string f64Func;
-  const std::string f32ApproxFunc;
-  const std::string f16Func;
 };
 
 } // namespace mlir
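
Note on the header change above: it is a CRTP split. OpToFuncCallLoweringBase keeps the shared matchAndRewrite driver, the function-type construction, and the declaration insertion (appendOrGetFuncOp), and dispatches to the derived pattern for exactly two policy hooks: maybeCast and getFunctionName. A minimal sketch of a further derived pattern, assuming the base class as defined in this patch (the pattern name and its fixed-function behavior are hypothetical, for illustration only):

  // Hypothetical derived pattern: always call one fixed runtime function,
  // with no operand casting and no fastmath-dependent name selection.
  template <typename SourceOp>
  struct FixedFuncCallLowering
      : public OpToFuncCallLoweringBase<SourceOp, FixedFuncCallLowering<SourceOp>> {
    explicit FixedFuncCallLowering(const LLVMTypeConverter &lowering,
                                   StringRef funcName)
        : OpToFuncCallLoweringBase<SourceOp, FixedFuncCallLowering<SourceOp>>(
              lowering),
          funcName(funcName) {}

    // Hook 1: leave operands untouched.
    Value maybeCast(Value operand, PatternRewriter &rewriter) const {
      return operand;
    }

    // Hook 2: one name for every result type; returning "" would instead
    // make the base class report match failure.
    StringRef getFunctionName(Type type, SourceOp op) const {
      return funcName;
    }

    const std::string funcName;
  };
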
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2768929f460e2..1971de30898fb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -446,6 +446,22 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
                                            f32ApproxFunc, f16Func);
 }
 
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+                                  RewritePatternSet &patterns,
+                                  StringRef i32Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<IntOpToFuncCallLowering<OpTy>>(converter, i32Func);
+}
+
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns,
+                                       StringRef f32Func, StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<FloatIntOpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<GPUSubgroupReduceOpLowering>(converter);
@@ -509,6 +525,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
 
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                     "__nv_fmod");
+  populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
   populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                    "__nv_fabs");
   populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
@@ -555,6 +572,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
                                     "__nv_log2", "__nv_fast_log2f");
   populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
                                    "__nv_pow", "__nv_fast_powf");
+  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
+                                            "__nv_powi");
   populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                     "__nv_round");
   populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
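
For reference, the two helpers added above are thin wrappers that preserve the one-call-per-op registration style of populateOpPatterns; the math.absi and math.fpowi registrations expand to plain pattern insertions, roughly:

  // Equivalent hand-written registrations (sketch; converter and patterns
  // as in populateGpuToNVVMConversionPatterns).
  patterns.add<ScalarizeVectorOpLowering<math::AbsIOp>>(converter);
  patterns.add<IntOpToFuncCallLowering<math::AbsIOp>>(converter, "__nv_abs");
  patterns.add<ScalarizeVectorOpLowering<math::FPowIOp>>(converter);
  patterns.add<FloatIntOpToFuncCallLowering<math::FPowIOp>>(converter,
                                                            "__nv_powif",
                                                            "__nv_powi");
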
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index f52dd6c0d0ce3..94c0f9e34c29c 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1033,3 +1033,27 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+gpu.module @test_module_52 {
+  // CHECK: llvm.func @__nv_abs(i32) -> i32
+  // CHECK-LABEL: func @gpu_abs
+  func.func @gpu_abs(%arg_i32 : i32) -> (i32) {
+    %result32 = math.absi %arg_i32 : i32
+    // CHECK: llvm.call @__nv_abs(%{{.*}}) : (i32) -> i32
+    func.return %result32 : i32
+  }
+}
+
+gpu.module @test_module_53 {
+  // CHECK: llvm.func @__nv_powif(f32, i32) -> f32
+  // CHECK: llvm.func @__nv_powi(f64, i32) -> f64
+  // CHECK-LABEL: func @gpu_powi
+  func.func @gpu_powi(%arg_f32 : f32, %arg_f64 : f64, %arg_i32 : i32) -> (f32, f64) {
+    %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
+    // CHECK: llvm.call @__nv_powif(%{{.*}}, %{{.*}}) : (f32, i32) -> f32
+    %result64 = math.fpowi %arg_f64, %arg_i32 : f64, i32
+    // CHECK: llvm.call @__nv_powi(%{{.*}}, %{{.*}}) : (f64, i32) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
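
Summary of the net effect the new tests check (a sketch only; SSA names are invented, and the enclosing gpu.module, function signatures, and generated llvm.func declarations are elided):

  %abs = math.absi %x : i32
    // becomes: llvm.call @__nv_abs(%x) : (i32) -> i32
  %p32 = math.fpowi %f, %i : f32, i32
    // becomes: llvm.call @__nv_powif(%f, %i) : (f32, i32) -> f32
  %p64 = math.fpowi %d, %i : f64, i32
    // becomes: llvm.call @__nv_powi(%d, %i) : (f64, i32) -> f64
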