Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,11 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr):

if is_xpu():
# use cpu result as reference, see https://github.com/llvm/llvm-project/issues/88222
out_ref = torch.div(x.cpu().to(torch.float64), y.cpu().to(torch.float64)).to(torch.float32).to(device=device)
if (expr_prec.count('sqrt') > 0):
out_ref = torch.sqrt(x.cpu().to(torch.float64)).to(torch.float32).to(device=device)
elif (expr_prec.count('div') > 0):
out_ref = torch.div(x.cpu().to(torch.float64),
y.cpu().to(torch.float64)).to(torch.float32).to(device=device)
assert torch.all(out == out_ref) # bitwise exact


Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/a770/language.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
Expand Down
3 changes: 0 additions & 3 deletions scripts/skiplist/conda/language.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
Expand Down Expand Up @@ -278,7 +276,6 @@ test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
test/unit/language/test_core.py::test_bin_op[1-uint8-uint8-%]
test/unit/language/test_core.py::test_bin_op[1-uint8-uint16-%]
test/unit/language/test_core.py::test_bin_op[1-uint8-uint32-%]
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/default/language.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
2 changes: 0 additions & 2 deletions scripts/skiplist/lts/language.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/mtl/language.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/xe2/language.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
28 changes: 28 additions & 0 deletions third_party/intel/language/intel/libdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,34 @@ def rcp_ru(arg0, _builder=None):
}, is_pure=True, _builder=_builder)


@core.extern
def sqrt_rn(arg0, _builder=None):
return core.extern_elementwise("", "", [arg0], {
(core.dtype("fp64"), ): ("__imf_sqrt_rn", core.dtype("fp64")),
}, is_pure=True, _builder=_builder)


@core.extern
def sqrt_rz(arg0, _builder=None):
return core.extern_elementwise("", "", [arg0], {
(core.dtype("fp64"), ): ("__imf_sqrt_rz", core.dtype("fp64")),
}, is_pure=True, _builder=_builder)


@core.extern
def sqrt_rd(arg0, _builder=None):
return core.extern_elementwise("", "", [arg0], {
(core.dtype("fp64"), ): ("__imf_sqrt_rd", core.dtype("fp64")),
}, is_pure=True, _builder=_builder)


@core.extern
def sqrt_ru(arg0, _builder=None):
return core.extern_elementwise("", "", [arg0], {
(core.dtype("fp64"), ): ("__imf_sqrt_ru", core.dtype("fp64")),
}, is_pure=True, _builder=_builder)


@core.extern
def sqrt(arg0, _builder=None):
return core.extern_elementwise(
Expand Down
32 changes: 30 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,34 @@ struct MulhiUIOpConversion
const TargetInfoBase &targetInfo;
};

struct PreciseSqrtOpConversion
: ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
using Base =
ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

SmallVector<Value> createDestOps(PreciseSqrtOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
Value input = operands[0][0];
Type origTy = input.getType();
if (!origTy.isF64())
input = fpext(f64_ty, input);
Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty});
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType);
LLVM::CallOp callOp =
LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input});
callOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
Value result = callOp.getResult();
if (!origTy.isF64())
result = rewriter.create<LLVM::FPTruncOp>(loc, origTy, result);
return {result};
}
};

template <typename TritonOp>
struct OpToExternCallConversion
: public ElementwiseOpConversionBase<TritonOp,
Expand Down Expand Up @@ -1462,8 +1490,8 @@ void populateElementwiseOpToLLVMPatterns(
PatternBenefit benefit) {
using namespace mlir::triton::gpu;

patterns.add<OpToExternCallConversion<triton::PreciseSqrtOp>>(
typeConverter, axisInfoAnalysis, "__imf_sqrtf", benefit);
patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
benefit);
patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);

Expand Down