Skip to content

Commit 2aae978

Browse files
[UT] Fix test_precise_math (#2776)
Before this PR, precise sqrt is lowered to `__imf_sqrtf`, which is not precise. `__imf_sqrt_rn` is added in the new device library updated in #2774. This PR uses `__imf_sqrt_rn` to get the precise sqrt implemented. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 1aebd43 commit 2aae978

File tree

9 files changed

+63
-16
lines changed

9 files changed

+63
-16
lines changed

python/test/unit/language/test_core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,11 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr):
11021102

11031103
if is_xpu():
11041104
# use cpu result as reference, see https://github.com/llvm/llvm-project/issues/88222
1105-
out_ref = torch.div(x.cpu().to(torch.float64), y.cpu().to(torch.float64)).to(torch.float32).to(device=device)
1105+
if (expr_prec.count('sqrt') > 0):
1106+
out_ref = torch.sqrt(x.cpu().to(torch.float64)).to(torch.float32).to(device=device)
1107+
elif (expr_prec.count('div') > 0):
1108+
out_ref = torch.div(x.cpu().to(torch.float64),
1109+
y.cpu().to(torch.float64)).to(torch.float32).to(device=device)
11061110
assert torch.all(out == out_ref) # bitwise exact
11071111

11081112

scripts/skiplist/a770/language.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
2-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
31
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
42
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
53
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]

scripts/skiplist/conda/language.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
113113
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
114114
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
115115
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
116-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
117-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
118116
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
119117
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
120118
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
@@ -278,7 +276,6 @@ test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
278276
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
279277
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
280278
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
281-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
282279
test/unit/language/test_core.py::test_bin_op[1-uint8-uint8-%]
283280
test/unit/language/test_core.py::test_bin_op[1-uint8-uint16-%]
284281
test/unit/language/test_core.py::test_bin_op[1-uint8-uint32-%]
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
2-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
31
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
42
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]

scripts/skiplist/lts/language.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
113113
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
114114
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
115115
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
116-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
117-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
118116
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
119117
test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
120118
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703

scripts/skiplist/mtl/language.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
2-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
31
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
42
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
53
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]

scripts/skiplist/xe2/language.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
2-
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
31
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
42
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]

third_party/intel/language/intel/libdevice.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,34 @@ def rcp_ru(arg0, _builder=None):
212212
}, is_pure=True, _builder=_builder)
213213

214214

215+
@core.extern
216+
def sqrt_rn(arg0, _builder=None):
217+
return core.extern_elementwise("", "", [arg0], {
218+
(core.dtype("fp64"), ): ("__imf_sqrt_rn", core.dtype("fp64")),
219+
}, is_pure=True, _builder=_builder)
220+
221+
222+
@core.extern
223+
def sqrt_rz(arg0, _builder=None):
224+
return core.extern_elementwise("", "", [arg0], {
225+
(core.dtype("fp64"), ): ("__imf_sqrt_rz", core.dtype("fp64")),
226+
}, is_pure=True, _builder=_builder)
227+
228+
229+
@core.extern
230+
def sqrt_rd(arg0, _builder=None):
231+
return core.extern_elementwise("", "", [arg0], {
232+
(core.dtype("fp64"), ): ("__imf_sqrt_rd", core.dtype("fp64")),
233+
}, is_pure=True, _builder=_builder)
234+
235+
236+
@core.extern
237+
def sqrt_ru(arg0, _builder=None):
238+
return core.extern_elementwise("", "", [arg0], {
239+
(core.dtype("fp64"), ): ("__imf_sqrt_ru", core.dtype("fp64")),
240+
}, is_pure=True, _builder=_builder)
241+
242+
215243
@core.extern
216244
def sqrt(arg0, _builder=None):
217245
return core.extern_elementwise(

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,34 @@ struct MulhiUIOpConversion
14201420
const TargetInfoBase &targetInfo;
14211421
};
14221422

1423+
struct PreciseSqrtOpConversion
1424+
: ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
1425+
using Base =
1426+
ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>;
1427+
using Base::Base;
1428+
using Adaptor = typename Base::OpAdaptor;
1429+
1430+
SmallVector<Value> createDestOps(PreciseSqrtOp op, Adaptor adaptor,
1431+
ConversionPatternRewriter &rewriter,
1432+
Type elemTy, MultipleOperandsRange operands,
1433+
Location loc) const {
1434+
Value input = operands[0][0];
1435+
Type origTy = input.getType();
1436+
if (!origTy.isF64())
1437+
input = fpext(f64_ty, input);
1438+
Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty});
1439+
LLVM::LLVMFuncOp funcOp =
1440+
appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType);
1441+
LLVM::CallOp callOp =
1442+
LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input});
1443+
callOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1444+
Value result = callOp.getResult();
1445+
if (!origTy.isF64())
1446+
result = rewriter.create<LLVM::FPTruncOp>(loc, origTy, result);
1447+
return {result};
1448+
}
1449+
};
1450+
14231451
template <typename TritonOp>
14241452
struct OpToExternCallConversion
14251453
: public ElementwiseOpConversionBase<TritonOp,
@@ -1462,8 +1490,8 @@ void populateElementwiseOpToLLVMPatterns(
14621490
PatternBenefit benefit) {
14631491
using namespace mlir::triton::gpu;
14641492

1465-
patterns.add<OpToExternCallConversion<triton::PreciseSqrtOp>>(
1466-
typeConverter, axisInfoAnalysis, "__imf_sqrtf", benefit);
1493+
patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
1494+
benefit);
14671495
patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
14681496
typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);
14691497

0 commit comments

Comments
 (0)