@@ -970,8 +970,8 @@ struct ElementwiseOpConversion
970970 ConversionPatternRewriter &rewriter,
971971 Type elemTy, MultipleOperandsRange operands,
972972 Location loc) const {
973- assert ((!getElementType (op. getLhs () ).isBF16 () &&
974- !getElementType (op. getRhs () ).isBF16 ()) &&
973+ assert ((!getElementType (operands[ 0 ][ 0 ] ).isBF16 () &&
974+ !getElementType (operands[ 0 ][ 1 ] ).isBF16 ()) &&
975975 " unsupported conversion" );
976976 return {
977977 rewriter.create <DestOp>(loc, elemTy, operands[0 ][0 ], operands[0 ][1 ])};
@@ -1146,57 +1146,11 @@ struct PreciseSqrtOpConversion
11461146 ConversionPatternRewriter &rewriter,
11471147 Type elemTy, MultipleOperandsRange operands,
11481148 Location loc) const {
1149- auto b = TritonLLVMOpBuilder (loc, rewriter);
1150- Value input = operands[0 ][0 ];
1151- Type origTy = input.getType ();
1152- if (!origTy.isF64 ())
1153- input = b.fpext (f64_ty, input);
1154- Type funcType = LLVM::LLVMFunctionType::get (f64_ty, {f64_ty});
1155- LLVM::LLVMFuncOp funcOp =
1156- appendOrGetExternFuncOp (rewriter, op, " __imf_sqrt_rn" , funcType);
1157- funcOp.setCConv (triton::gpu::intel::getDefaultCConv (op));
1158- LLVM::CallOp callOp =
1159- LLVM::createLLVMCallOp (rewriter, loc, funcOp, {input});
1160- callOp.setCConv (funcOp.getCConv ());
1161- Value result = callOp.getResult ();
1162- if (!origTy.isF64 ())
1163- result = rewriter.create <LLVM::FPTruncOp>(loc, origTy, result);
1164- return {result};
1165- }
1166- };
1167-
1168- template <typename TritonOp>
1169- struct OpToExternCallConversion
1170- : public ElementwiseOpConversionBase<TritonOp,
1171- OpToExternCallConversion<TritonOp>> {
1172- using Base =
1173- ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
1174- using Base::Base;
1175- using Adaptor = typename Base::OpAdaptor;
1176-
1177- explicit OpToExternCallConversion (LLVMTypeConverter &typeConverter,
1178- ModuleAxisInfoAnalysis &axisAnalysisPass,
1179- StringRef externFuncName,
1180- PatternBenefit benefit)
1181- : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
1182- benefit),
1183- funcName(externFuncName) {}
1184-
1185- SmallVector<Value> createDestOps (TritonOp op, Adaptor adaptor,
1186- ConversionPatternRewriter &rewriter,
1187- Type elemTy, MultipleOperandsRange operands,
1188- Location loc) const {
1189- Type funcType = getFunctionType (elemTy, operands[0 ]);
1190- LLVM::LLVMFuncOp funcOp =
1191- appendOrGetExternFuncOp (rewriter, op, funcName, funcType);
1192- funcOp.setCConv (triton::gpu::intel::getDefaultCConv (op));
1193- auto callOp = LLVM::createLLVMCallOp (rewriter, loc, funcOp, operands[0 ]);
1194- callOp.setCConv (funcOp.getCConv ());
1195- return {callOp.getResult ()};
1149+ // FIXME: Use precise sqrt builtin: #5419
1150+ // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
1151+ return {rewriter.create <LLVM::SqrtOp>(loc, elemTy, operands[0 ],
1152+ adaptor.getAttributes ().getValue ())};
11961153 }
1197-
1198- private:
1199- StringRef funcName;
12001154};
12011155
12021156// Following two patterns are copied from the common part to fix-up calling
@@ -1273,8 +1227,10 @@ void populateElementwiseOpToLLVMPatterns(
12731227
12741228 patterns.add <PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
12751229 benefit);
1276- patterns.add <OpToExternCallConversion<triton::PreciseDivFOp>>(
1277- typeConverter, axisInfoAnalysis, " __imf_fdiv_rn" , benefit);
1230+ // FIXME: Use precise divide builtin: #5419
1231+ // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
1232+ patterns.add <ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
1233+ typeConverter, axisInfoAnalysis, benefit);
12781234 patterns.add <MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
12791235 benefit);
12801236 patterns.add <ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
0 commit comments