@@ -1420,6 +1420,33 @@ struct MulhiUIOpConversion
14201420 const TargetInfoBase &targetInfo;
14211421};
14221422
1423+ struct PreciseSqrtOpConversion
1424+ : ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
1425+ using Base =
1426+ ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>;
1427+ using Base::Base;
1428+ using Adaptor = typename Base::OpAdaptor;
1429+
1430+ SmallVector<Value> createDestOps (PreciseSqrtOp op, Adaptor adaptor,
1431+ ConversionPatternRewriter &rewriter,
1432+ Type elemTy, MultipleOperandsRange operands,
1433+ Location loc) const {
1434+ auto input = operands[0 ][0 ];
1435+ auto origTy = input.getType ();
1436+ if (!origTy.isF64 ())
1437+ input = fpext (f64_ty, input);
1438+ Type funcType = LLVM::LLVMFunctionType::get (f64_ty, {f64_ty});
1439+ LLVM::LLVMFuncOp funcOp =
1440+ appendOrGetExternFuncOp (rewriter, op, " __imf_sqrt_rn" , funcType);
1441+ auto callOp = LLVM::createLLVMCallOp (rewriter, loc, funcOp, {input});
1442+ callOp.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
1443+ auto result = callOp.getResult ();
1444+ if (!origTy.isF64 ())
1445+ result = rewriter.create <LLVM::FPTruncOp>(loc, origTy, result);
1446+ return {result};
1447+ }
1448+ };
1449+
14231450template <typename TritonOp>
14241451struct OpToExternCallConversion
14251452 : public ElementwiseOpConversionBase<TritonOp,
@@ -1462,8 +1489,8 @@ void populateElementwiseOpToLLVMPatterns(
14621489 PatternBenefit benefit) {
14631490 using namespace mlir ::triton::gpu;
14641491
1465- patterns.add <OpToExternCallConversion<triton::PreciseSqrtOp>>(
1466- typeConverter, axisInfoAnalysis, " __imf_sqrtf " , benefit);
1492+ patterns.add <PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
1493+ benefit);
14671494 patterns.add <OpToExternCallConversion<triton::PreciseDivFOp>>(
14681495 typeConverter, axisInfoAnalysis, " __imf_fdiv_rn" , benefit);
14691496
0 commit comments