diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 46fd182346b3b..9f7ceb11752ba 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -16,10 +16,16 @@ namespace mlir { -/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or -/// `f32ApproxFunc` or `f16Func` depending on the element type and the -/// fastMathFlag of that Op. The function declaration is added in case it was -/// not added before. +namespace { +/// Detection trait for the `getFastmath` instance method. +template <typename T> +using has_get_fastmath_t = decltype(std::declval<T>().getFastmath()); +} // namespace + +/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or +/// `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and +/// the fastMathFlag of that Op, if present. The function declaration is added +/// in case it was not added before. 
/// /// If the input values are of bf16 type (or f16 type if f16Func is empty), the /// value is first casted to f32, the function called and then the result casted /// back. @@ -39,14 +45,22 @@ namespace mlir { /// /// will be transformed into /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 +/// +/// Final example with NVVM: +/// %pow_f32 = math.fpowi %arg_f32, %arg_i32 +/// +/// will be transformed into +/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32 template <typename SourceOp> struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { public: explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, - StringRef f32ApproxFunc, StringRef f16Func) + StringRef f32ApproxFunc, StringRef f16Func, + StringRef i32Func = "") : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func), - f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {} + f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func), + i32Func(i32Func) {} LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, @@ -76,9 +90,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { Type resultType = castedOperands.front().getType(); Type funcType = getFunctionType(resultType, castedOperands); - StringRef funcName = - getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), - op.getFastmath()); + StringRef funcName = getFunctionName( + cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op); if (funcName.empty()) return failure(); @@ -91,6 +104,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { return success(); } + assert(callOp.getResult().getType().isF32() && + "only f32 types are supposed to be truncated back"); Value truncated = rewriter.create<LLVM::FPTruncOp>( op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); @@ -98,7 +113,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { return success(); } -private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { Type type = operand.getType(); if 
(!isa<Float16Type, BFloat16Type>(type)) @@ -117,38 +131,50 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { return LLVM::LLVMFunctionType::get(resultType, operandTypes); } - StringRef getFunctionName(Type type, arith::FastMathFlags flag) const { - if (isa<Float16Type>(type)) - return f16Func; - if (isa<Float32Type>(type)) { - if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) && - !f32ApproxFunc.empty()) - return f32ApproxFunc; - else - return f32Func; - } - if (isa<Float64Type>(type)) - return f64Func; - return ""; - } - LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const { using LLVM::LLVMFuncOp; auto funcAttr = StringAttr::get(op->getContext(), funcName); - Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + auto funcOp = + SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr); if (funcOp) - return cast<LLVMFuncOp>(*funcOp); + return funcOp; - mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>()); + auto parentFunc = op->getParentOfType<FunctionOpInterface>(); + assert(parentFunc && "expected there to be a parent function"); + OpBuilder b(parentFunc); return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); } + StringRef getFunctionName(Type type, SourceOp op) const { + bool useApprox = false; + if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) { + arith::FastMathFlags flag = op.getFastmath(); + useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) && + !f32ApproxFunc.empty(); + } + + if (isa<Float16Type>(type)) + return f16Func; + if (isa<Float32Type>(type)) { + if (useApprox) + return f32ApproxFunc; + return f32Func; + } + if (isa<Float64Type>(type)) + return f64Func; + + if (type.isInteger(32)) + return i32Func; + return ""; + } + const std::string f32Func; const std::string f64Func; const std::string f32ApproxFunc; const std::string f16Func; + const std::string i32Func; }; } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 2768929f460e2..11363a0d60ebf 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ 
b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -446,6 +446,22 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +template <typename OpTy> +static void populateIntOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + StringRef i32Func) { + patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter); + patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func); +} + +template <typename OpTy> +static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + StringRef f32Func, StringRef f64Func) { + patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter); + patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", ""); +} + void mlir::populateGpuSubgroupReduceOpLoweringPattern( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add<GPUSubgroupReduceOpLowering>(converter); @@ -509,6 +525,7 @@ void mlir::populateGpuToNVVMConversionPatterns( populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf", "__nv_fmod"); + populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs"); populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf", "__nv_fabs"); populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", @@ -555,6 +572,8 @@ void mlir::populateGpuToNVVMConversionPatterns( "__nv_log2", "__nv_fast_log2f"); populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow", "__nv_fast_powf"); + populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif", + "__nv_powi"); populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round"); populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index f52dd6c0d0ce3..94c0f9e34c29c 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1033,3 +1033,27 @@ module attributes {transform.with_named_sequence} { transform.yield } } + + +gpu.module @test_module_52 { + // CHECK: llvm.func @__nv_abs(i32) -> i32 + // CHECK-LABEL: func @gpu_abs + func.func @gpu_abs(%arg_i32 : i32) -> 
(i32) { + %result32 = math.absi %arg_i32 : i32 + // CHECK: llvm.call @__nv_abs(%{{.*}}) : (i32) -> i32 + func.return %result32 : i32 + } +} + +gpu.module @test_module_53 { + // CHECK: llvm.func @__nv_powif(f32, i32) -> f32 + // CHECK: llvm.func @__nv_powi(f64, i32) -> f64 + // CHECK-LABEL: func @gpu_powi + func.func @gpu_powi(%arg_f32 : f32, %arg_f64 : f64, %arg_i32 : i32) -> (f32, f64) { + %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32 + // CHECK: llvm.call @__nv_powif(%{{.*}}, %{{.*}}) : (f32, i32) -> f32 + %result64 = math.fpowi %arg_f64, %arg_i32 : f64, i32 + // CHECK: llvm.call @__nv_powi(%{{.*}}, %{{.*}}) : (f64, i32) -> f64 + func.return %result32, %result64 : f32, f64 + } +}