Commit 0e868f8

Revert "[AMD] Improve math.fdiv FTZ lowering for f32 inputs" (#7163)
This reverts commit bde92ef. PyTorch issue: pytorch/pytorch#154215

This PR causes numerics issues upstream in PyTorch. Minimal repro below (passes on Triton 3.2 but fails on main):

```python
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /tmp/tmpwrci5onu/wd/cwdtyp47mdxa7gqc4clytt4lh4mwprgdsb73bx3jld5vkmv2gu6d.py
# Topologically Sorted Source Nodes: [upsample_nearest1d], Original ATen: [aten._unsafe_index]
# Source node to ATen node mapping:
#   upsample_nearest1d => _unsafe_index
# Graph fragment:
#   %_unsafe_index : [num_users=1] = call_function[target=torch.ops.aten._unsafe_index.Tensor](args = (%arg3_1, [None, None, %convert_element_type_1]), kwargs = {})
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()


@triton.jit
def triton_poi_fused__unsafe_index_0(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 74)
    x1 = xindex // 74
    x2 = xindex
    tmp0 = ks0 / 74
    tmp1 = tmp0.to(tl.float32)
    tmp2 = x0
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp1
    tmp5 = tmp4.to(tl.int64)
    tmp6 = tl.load(in_ptr0 + (tmp5 + ks0*x1), xmask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (x2), tmp6, xmask)


def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s86 = arg0_1
    s38 = arg1_1
    s32 = arg2_1
    assert_size_stride(arg3_1, (s86, s38, s32), (s32*s38, s32, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s86, s38, 74), (74*s38, 74, 1), torch.float32)
        # Topologically Sorted Source Nodes: [upsample_nearest1d], Original ATen: [aten._unsafe_index]
        triton_poi_fused__unsafe_index_0_xnumel = 74*s38*s86
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_index_0[(3, 1, 1)](arg3_1, buf0, 37, 592, 256, num_warps=1, num_stages=1)
        del arg3_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 2
    arg1_1 = 4
    arg2_1 = 37
    arg3_1 = rand_strided((2, 4, 37), (148, 37, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
    triton = call([arg0_1, arg1_1, arg2_1, arg3_1])[0]
    ref = aten.upsample_nearest1d(arg3_1, [74], None)
    print(triton - ref)
    assert torch.allclose(triton, ref, atol=1e-2, rtol=1e-2)
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```
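The repro is numerically delicate by construction: the kernel computes its gather scale as `tmp0 = ks0 / 74` with `ks0 = 37`, which an IEEE-compliant f32 divide returns as exactly 0.5, and then truncates `x0 * tmp0` to build load indices. A quotient even one ulp below 0.5 shifts those indices. Below is a minimal standalone sketch of that failure mode; the `f32` and `one_ulp_down` helpers are written here purely for illustration, under the assumption that the approximate lowering can return a result one ulp below the correctly rounded quotient:

```python
import struct

def f32(x):
    """Round a Python float to the nearest binary32 value."""
    return struct.unpack('f', struct.pack('f', x))[0]

def one_ulp_down(x):
    """Largest binary32 value strictly below a positive binary32 x."""
    bits = struct.unpack('I', struct.pack('f', x))[0]
    return struct.unpack('f', struct.pack('I', bits - 1))[0]

exact = f32(37.0 / 74.0)      # IEEE divide: exactly 0.5
approx = one_ulp_down(exact)  # 0.49999997, a plausible fast-path quotient

# The kernel builds its gather indices as int(x0 * (ks0 / 74)):
for x0 in range(4):
    print(x0, int(x0 * exact), int(x0 * approx))
# At x0 = 2 the exact scale gives index 1 but the one-ulp-low scale
# gives index 0, so upsample_nearest1d reads the wrong input element.
```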
1 parent 2f81756 · commit 0e868f8

File tree

2 files changed (+2, -96 lines)


test/Conversion/amd/fdivide.mlir

Lines changed: 0 additions & 41 deletions
This file was deleted.

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 55 deletions
```diff
@@ -1408,62 +1408,9 @@ struct FDivOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    // For non-F32 input, it's lowered to LLVM::FDivOp, which is a
-    // IEEE-compliant DIV operation.
-    if (elemTy.getIntOrFloatBitWidth() != 32)
-      return {rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0][0],
-                                            operands[0][1])};
-
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
 
-    // The algorithm comes from
-    // https://github.com/llvm/llvm-project/blob/bda7aadf/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L4980-L5065
-    // with the Newton-Raphson refinement removed, to perform a faster,
-    // approximated DIV operation, aligning with the `div.full.f32` instruction
-    // on the NV backend.
-    Value &lhs = operands[0][0];
-    Value &rhs = operands[0][1];
-    MLIRContext *ctx = rewriter.getContext();
-    Type divScaleResType = struct_ty({elemTy, i1_ty});
-
-    // The `llvm.amdgcn.div.scale.f32` instruction's signature is
-    // (src0, src1, src2) -> (ret0, ret1), where
-    //
-    // src0: The numerator or lhs of FDivOp.
-    // src1: The denominator or rhs of FDivOp.
-    // src2: A boolean indicating which operand to scale. If true, lhs is
-    //       scaled; Otherwise, rhs is scaled.
-    //
-    // ret0: The scaled operand.
-    // ret1: The VCC register indicating whether post-scaling is required.
-    auto denominatorScaleOp = LLVM::createLLVMIntrinsicCallOp(
-        rewriter, loc, "llvm.amdgcn.div.scale.f32", divScaleResType,
-        {lhs, rhs, b.false_val()});
-    Value denominatorScaled = b.extract_val(denominatorScaleOp.getResult(0), 0);
-    auto numeratorScaleOp = LLVM::createLLVMIntrinsicCallOp(
-        rewriter, loc, "llvm.amdgcn.div.scale.f32", divScaleResType,
-        {lhs, rhs, b.true_val()});
-    Value numeratorScaled = b.extract_val(numeratorScaleOp.getResult(0), 0);
-    Value vcc = b.extract_val(numeratorScaleOp.getResult(0), 1);
-
-    Value rcp =
-        LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.rcp.f32",
-                                        elemTy, {denominatorScaled})
-            .getResult(0);
-
-    Value approxDiv = b.fmul(numeratorScaled, rcp);
-
-    // Since the Newton-Raphson is skipped, we use 0 instead of approximations
-    // as the inputs.
-    auto fmas = LLVM::createLLVMIntrinsicCallOp(
-                    rewriter, loc, "llvm.amdgcn.div.fmas.f32", elemTy,
-                    {b.f32_val(0), b.f32_val(0), approxDiv, vcc})
-                    .getResult(0);
-
-    return {LLVM::createLLVMIntrinsicCallOp(rewriter, loc,
-                                            "llvm.amdgcn.div.fixup.f32", elemTy,
-                                            {fmas, rhs, lhs})
-                .getResult(0)};
+    return {rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0][0],
+                                          operands[0][1])};
   }
 };
```
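For context on the removed fast path: it multiplied the scaled numerator by a single hardware reciprocal of the scaled denominator, skipping the Newton-Raphson refinement that the full IEEE-correct sequence uses to absorb the reciprocal's roughly 1 ulp error. The sketch below is a rough scalar model of that behavior in plain Python, not the actual intrinsics; it assumes `div.scale` acts as an identity for in-range operands and pessimistically models `llvm.amdgcn.rcp.f32` as landing one ulp below the correctly rounded reciprocal, which a ~1 ulp accuracy bound permits:

```python
import struct

def f32(x):
    # Round a Python double to the nearest binary32 value.
    return struct.unpack('f', struct.pack('f', x))[0]

def rcp_f32(x):
    # Stand-in for llvm.amdgcn.rcp.f32: take the correctly rounded
    # binary32 reciprocal, then step one ulp down to model the worst
    # case allowed by a ~1 ulp accuracy bound.
    bits = struct.unpack('I', struct.pack('f', f32(1.0 / x)))[0]
    return struct.unpack('f', struct.pack('I', bits - 1))[0]

def approx_fdiv_f32(lhs, rhs):
    # Shape of the reverted lowering: scale, one reciprocal, one
    # multiply, fix up -- with Newton-Raphson skipped, the rcp error
    # flows straight into the quotient.
    return f32(f32(lhs) * rcp_f32(f32(rhs)))

print(approx_fdiv_f32(37.0, 74.0))  # 0.49999997...
print(f32(37.0 / 74.0))             # IEEE FDiv: exactly 0.5
```

Under this model the repro's `37 / 74` comes out one ulp below 0.5, which is precisely the slip that flips the `_unsafe_index` gather indices in the kernel above.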
