@@ -2264,73 +2264,6 @@ struct SqrtOpConversion
22642264 }
22652265 }
22662266
2267- private:
2268- bool ftz;
2269- };
2270-
2271- struct PreciseSqrtOpConversion
2272- : ElementwiseOpConversionBase<triton::PreciseSqrtOp,
2273- PreciseSqrtOpConversion> {
2274- explicit PreciseSqrtOpConversion (LLVMTypeConverter &typeConverter,
2275- ModuleAxisInfoAnalysis &axisInfoAnalysis,
2276- bool ftz, PatternBenefit benefit)
2277- : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit),
2278- ftz(ftz) {}
2279-
2280- SmallVector<Value> createDestOps (triton::PreciseSqrtOp op, OpAdaptor adaptor,
2281- ConversionPatternRewriter &rewriter,
2282- Type elemTy, MultipleOperandsRange operands,
2283- Location loc) const {
2284- auto b = TritonLLVMOpBuilder (loc, rewriter);
2285- // Unless the op is FP32 with denorm flushing (ftz) enabled, it is lowered
2286- // directly to LLVM::SqrtOp.
2287- if (elemTy.getIntOrFloatBitWidth () != 32 || !ftz) {
2288- return {LLVM::SqrtOp::create (rewriter, loc, elemTy, operands[0 ],
2289- adaptor.getAttributes ().getValue ())};
2290- }
2291-
2292- // On the AMDGPU backend, instructions legalized from LLVM::SqrtOp are
2293- // designed to always preserve denorms, according to
2294- // https://github.com/llvm/llvm-project/blob/3d6b2d49/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L5235-L5314.
2295- //
2296- // For f32 inputs with ftz enabled, we need to manually lower the op to
2297- // bypass the scaling-up-and-down process while keeping other parts
2298- // unchanged. To ensure IEEE-compliant results, we approximate `sqrt(x)`
2299- // using `x * rsq(x)` and apply extra refinement iterations to correct the
2300- // result.
2301- StringRef funcName = " llvm.amdgcn.rsq.f32" ;
2302-
2303- Type funcType = getFunctionType (elemTy, operands[0 ]);
2304- LLVM::LLVMFuncOp funcOp =
2305- appendOrGetExternFuncOp (rewriter, op, funcName, funcType);
2306-
2307- Value sqrtR =
2308- LLVM::createLLVMCallOp (rewriter, loc, funcOp, operands[0 ]).getResult ();
2309-
2310- Value sqrtX = operands[0 ][0 ];
2311- Value sqrtS = b.fmul (f32_ty, sqrtX, sqrtR);
2312-
2313- // Refine the approximation with Newton iteration
2314- Value sqrtH = b.fmul (f32_ty, sqrtR, b.f32_val (0 .5f ));
2315- Value sqrtE = b.fma (b.neg (f32_ty, sqrtH), sqrtS, b.f32_val (0 .5f ));
2316- sqrtH = b.fma (sqrtH, sqrtE, sqrtH);
2317- sqrtS = b.fma (sqrtS, sqrtE, sqrtS);
2318- Value sqrtD = b.fma (b.neg (f32_ty, sqrtS), sqrtS, sqrtX);
2319- sqrtS = b.fma (sqrtD, sqrtH, sqrtS);
2320-
2321- // Handle +0/-0/+inf
2322- // These flags come from
2323- // https://github.com/llvm/llvm-project/blob/217e0f39/llvm/include/llvm/ADT/FloatingPointMode.h#L239-L265.
2324- const unsigned fcPosInf = 0x0200 ;
2325- const unsigned fcNegZero = 0x0020 ;
2326- const unsigned fcPosZero = 0x0040 ;
2327- const unsigned fcZero = fcNegZero | fcPosZero;
2328-
2329- Value isZeroOrPosInf =
2330- LLVM::IsFPClass::create (rewriter, loc, i1_ty, sqrtX, fcPosInf | fcZero);
2331- return {b.select (isZeroOrPosInf, sqrtX, sqrtS)};
2332- }
2333-
23342267private:
23352268 bool ftz;
23362269};
@@ -2382,6 +2315,8 @@ void populateElementwiseOpToLLVMPatterns(
23822315 typeConverter, axisInfoAnalysis, benefit);
23832316 patterns.add <ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
23842317 typeConverter, axisInfoAnalysis, benefit);
2318+ patterns.add <ElementwiseOpConversion<triton::PreciseSqrtOp, LLVM::SqrtOp>>(
2319+ typeConverter, axisInfoAnalysis, benefit);
23852320
23862321 patterns.add <FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
23872322 patterns.add <FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit);
@@ -2409,8 +2344,6 @@ void populateElementwiseOpToLLVMPatterns(
24092344 patterns.add <RsqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz,
24102345 benefit);
24112346 patterns.add <SqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit);
2412- patterns.add <PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz,
2413- benefit);
24142347 triton::populateElementwiseOpToLLVMPatterns (
24152348 typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
24162349 bool hwNanPropagationSupported = targetInfo.supportMaximumMinimum ();
0 commit comments