Skip to content

Commit 9fb66fb

Browse files
mengfei-jiang authored and anmyachev committed
[AMD] Preserve Denorms for precise sqrt (#8697)
This commit modifies the denorm behavior for precise sqrt: switching from FTZ (Flush To Zero) to denorm preservation.
1 parent ce3d636 commit 9fb66fb

File tree

2 files changed

+16
-85
lines changed

2 files changed

+16
-85
lines changed

test/Conversion/amd/math-denorm-handling.mlir

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -64,22 +64,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
6464
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
6565
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
6666
tt.func public @test_sqrt_rn_f32(%arg0: tensor<64xf32, #blocked>) {
67-
// LLVM_FTZ-LABEL: test_sqrt_rn_f32
68-
// LLVM_FTZ: llvm.amdgcn.rsq.f32
69-
// LLVM_FTZ: llvm.fmul
70-
// LLVM_FTZ: llvm.fmul
71-
// LLVM_FTZ: llvm.fneg
72-
// LLVM_FTZ: llvm.intr.fma
73-
// LLVM_FTZ-NEXT: llvm.intr.fma
74-
// LLVM_FTZ-NEXT: llvm.intr.fma
75-
// LLVM_FTZ-NEXT: llvm.fneg
76-
// LLVM_FTZ-NEXT: llvm.intr.fma
77-
// LLVM_FTZ-NEXT: llvm.intr.fma
78-
// LLVM_FTZ-NEXT: llvm.intr.is.fpclass
79-
// LLVM_FTZ-NEXT: llvm.select
80-
//
81-
// LLVM_NO_FTZ-LABEL: test_sqrt_rn_f32
82-
// LLVM_NO_FTZ: llvm.intr.sqrt
67+
// COMMON-LABEL: test_sqrt_rn_f32
68+
// COMMON: llvm.intr.sqrt
8369
%0 = tt.precise_sqrt %arg0 : tensor<64xf32, #blocked>
8470
tt.return
8571
}
@@ -96,3 +82,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
9682
tt.return
9783
}
9884
}
85+
86+
// -----
87+
88+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
89+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
90+
tt.func public @test_divf_rn_f32(%arg0: tensor<64xf32, #blocked>, %arg1: tensor<64xf32, #blocked>) {
91+
// COMMON-LABEL: test_divf_rn_f32
92+
// COMMON: llvm.fdiv
93+
%0 = tt.precise_divf %arg0, %arg1 : tensor<64xf32, #blocked>
94+
tt.return
95+
}
96+
}

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -2264,73 +2264,6 @@ struct SqrtOpConversion
22642264
}
22652265
}
22662266

2267-
private:
2268-
bool ftz;
2269-
};
2270-
2271-
struct PreciseSqrtOpConversion
2272-
: ElementwiseOpConversionBase<triton::PreciseSqrtOp,
2273-
PreciseSqrtOpConversion> {
2274-
explicit PreciseSqrtOpConversion(LLVMTypeConverter &typeConverter,
2275-
ModuleAxisInfoAnalysis &axisInfoAnalysis,
2276-
bool ftz, PatternBenefit benefit)
2277-
: ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit),
2278-
ftz(ftz) {}
2279-
2280-
SmallVector<Value> createDestOps(triton::PreciseSqrtOp op, OpAdaptor adaptor,
2281-
ConversionPatternRewriter &rewriter,
2282-
Type elemTy, MultipleOperandsRange operands,
2283-
Location loc) const {
2284-
auto b = TritonLLVMOpBuilder(loc, rewriter);
2285-
// If the op is neither FP32 nor denorm flushing(ftz), it's directly lowered
2286-
// to LLVM::SqrtOp.
2287-
if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz) {
2288-
return {LLVM::SqrtOp::create(rewriter, loc, elemTy, operands[0],
2289-
adaptor.getAttributes().getValue())};
2290-
}
2291-
2292-
// On the AMDGPU backend, instructions legalized from LLVM::SqrtOp are
2293-
// designed to always preserve denorms, according to
2294-
// https://github.com/llvm/llvm-project/blob/3d6b2d49/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L5235-L5314.
2295-
//
2296-
// For f32 inputs with ftz enabled, we need to manually lower the op to
2297-
// bypass the scaling-up-and-down process while keeping other parts
2298-
// unchanged. To ensure IEEE-compliant results, we approximate `sqrt(x)`
2299-
// using `x * rsq(x)` and apply extra refinement iterations to correct the
2300-
// result.
2301-
StringRef funcName = "llvm.amdgcn.rsq.f32";
2302-
2303-
Type funcType = getFunctionType(elemTy, operands[0]);
2304-
LLVM::LLVMFuncOp funcOp =
2305-
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
2306-
2307-
Value sqrtR =
2308-
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult();
2309-
2310-
Value sqrtX = operands[0][0];
2311-
Value sqrtS = b.fmul(f32_ty, sqrtX, sqrtR);
2312-
2313-
// Refine the approximation with Newton iteration
2314-
Value sqrtH = b.fmul(f32_ty, sqrtR, b.f32_val(0.5f));
2315-
Value sqrtE = b.fma(b.neg(f32_ty, sqrtH), sqrtS, b.f32_val(0.5f));
2316-
sqrtH = b.fma(sqrtH, sqrtE, sqrtH);
2317-
sqrtS = b.fma(sqrtS, sqrtE, sqrtS);
2318-
Value sqrtD = b.fma(b.neg(f32_ty, sqrtS), sqrtS, sqrtX);
2319-
sqrtS = b.fma(sqrtD, sqrtH, sqrtS);
2320-
2321-
// Handle +0/-0/+inf
2322-
// These flags come from
2323-
// https://github.com/llvm/llvm-project/blob/217e0f39/llvm/include/llvm/ADT/FloatingPointMode.h#L239-L265.
2324-
const unsigned fcPosInf = 0x0200;
2325-
const unsigned fcNegZero = 0x0020;
2326-
const unsigned fcPosZero = 0x0040;
2327-
const unsigned fcZero = fcNegZero | fcPosZero;
2328-
2329-
Value isZeroOrPosInf =
2330-
LLVM::IsFPClass::create(rewriter, loc, i1_ty, sqrtX, fcPosInf | fcZero);
2331-
return {b.select(isZeroOrPosInf, sqrtX, sqrtS)};
2332-
}
2333-
23342267
private:
23352268
bool ftz;
23362269
};
@@ -2382,6 +2315,8 @@ void populateElementwiseOpToLLVMPatterns(
23822315
typeConverter, axisInfoAnalysis, benefit);
23832316
patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
23842317
typeConverter, axisInfoAnalysis, benefit);
2318+
patterns.add<ElementwiseOpConversion<triton::PreciseSqrtOp, LLVM::SqrtOp>>(
2319+
typeConverter, axisInfoAnalysis, benefit);
23852320

23862321
patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
23872322
patterns.add<FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit);
@@ -2409,8 +2344,6 @@ void populateElementwiseOpToLLVMPatterns(
24092344
patterns.add<RsqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz,
24102345
benefit);
24112346
patterns.add<SqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit);
2412-
patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz,
2413-
benefit);
24142347
triton::populateElementwiseOpToLLVMPatterns(
24152348
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
24162349
bool hwNanPropagationSupported = targetInfo.supportMaximumMinimum();

0 commit comments

Comments
 (0)