|
4 | 4 | #include "mlir/IR/MLIRContext.h" |
5 | 5 | #include "third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h" |
6 | 6 | #include "third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" |
| 7 | +#include "third_party/intel/lib/Utils/Mangling.h" |
7 | 8 | #include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" |
8 | 9 | #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" |
9 | 10 | #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" |
@@ -1135,22 +1136,41 @@ struct AbsFOpConversion |
1135 | 1136 | } |
1136 | 1137 | }; |
1137 | 1138 |
|
1138 | | -struct PreciseSqrtOpConversion |
1139 | | - : ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> { |
| 1139 | +template <typename TritonOp> |
| 1140 | +struct OpToExternCallConversion |
| 1141 | + : public ElementwiseOpConversionBase<TritonOp, |
| 1142 | + OpToExternCallConversion<TritonOp>> { |
1140 | 1143 | using Base = |
1141 | | - ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>; |
| 1144 | + ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>; |
1142 | 1145 | using Base::Base; |
1143 | 1146 | using Adaptor = typename Base::OpAdaptor; |
1144 | 1147 |
|
1145 | | - SmallVector<Value> createDestOps(PreciseSqrtOp op, Adaptor adaptor, |
| 1148 | + explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter, |
| 1149 | + ModuleAxisInfoAnalysis &axisAnalysisPass, |
| 1150 | + StringRef externFuncName, |
| 1151 | + PatternBenefit benefit) |
| 1152 | + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, |
| 1153 | + benefit), |
| 1154 | + funcName(externFuncName) {} |
| 1155 | + |
| 1156 | + SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor, |
1146 | 1157 | ConversionPatternRewriter &rewriter, |
1147 | 1158 | Type elemTy, MultipleOperandsRange operands, |
1148 | 1159 | Location loc) const { |
1149 | | - // FIXME: Use precise sqrt builtin: #5419 |
1150 | | - // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt. |
1151 | | - return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0], |
1152 | | - adaptor.getAttributes().getValue())}; |
| 1160 | + Type funcType = getFunctionType(elemTy, operands[0]); |
| 1161 | + SmallVector<Type> operandTypes(ValueRange(operands[0]).getTypes()); |
| 1162 | + std::string fnName = |
| 1163 | + mlir::triton::gpu::intel::mangle(funcName, operandTypes); |
| 1164 | + LLVM::LLVMFuncOp funcOp = |
| 1165 | + appendOrGetExternFuncOp(rewriter, op, fnName, funcType); |
| 1166 | + funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op)); |
| 1167 | + auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]); |
| 1168 | + callOp.setCConv(funcOp.getCConv()); |
| 1169 | + return {callOp.getResult()}; |
1153 | 1170 | } |
| 1171 | + |
| 1172 | +private: |
| 1173 | + StringRef funcName; |
1154 | 1174 | }; |
1155 | 1175 |
|
1156 | 1176 | // Following two patterns are copied from the common part to fix-up calling |
@@ -1225,12 +1245,10 @@ void populateElementwiseOpToLLVMPatterns( |
1225 | 1245 | ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, |
1226 | 1246 | PatternBenefit benefit) { |
1227 | 1247 |
|
1228 | | - patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis, |
1229 | | - benefit); |
1230 | | - // FIXME: Use precise divide builtin: #5419 |
1231 | | - // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide. |
1232 | | - patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>( |
1233 | | - typeConverter, axisInfoAnalysis, benefit); |
| 1248 | + patterns.add<OpToExternCallConversion<triton::PreciseSqrtOp>>( |
| 1249 | + typeConverter, axisInfoAnalysis, "sqrt_cr", benefit); |
| 1250 | + patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>( |
| 1251 | + typeConverter, axisInfoAnalysis, "divide_cr", benefit); |
1234 | 1252 | patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo, |
1235 | 1253 | benefit); |
1236 | 1254 | patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis, |
|
0 commit comments