diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 84f1f57d12..a68d26ae9e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1102,7 +1102,11 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): if is_xpu(): # use cpu result as reference, see https://github.com/llvm/llvm-project/issues/88222 - out_ref = torch.div(x.cpu().to(torch.float64), y.cpu().to(torch.float64)).to(torch.float32).to(device=device) + if (expr_prec.count('sqrt') > 0): + out_ref = torch.sqrt(x.cpu().to(torch.float64)).to(torch.float32).to(device=device) + elif (expr_prec.count('div') > 0): + out_ref = torch.div(x.cpu().to(torch.float64), + y.cpu().to(torch.float64)).to(torch.float32).to(device=device) assert torch.all(out == out_ref) # bitwise exact diff --git a/scripts/skiplist/a770/language.txt b/scripts/skiplist/a770/language.txt index 7e3e8d62fc..7025af3afe 100644 --- a/scripts/skiplist/a770/language.txt +++ b/scripts/skiplist/a770/language.txt @@ -1,5 +1,3 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16] diff --git a/scripts/skiplist/conda/language.txt b/scripts/skiplist/conda/language.txt index 1f2dcf0d10..d31a3fb96f 100644 --- a/scripts/skiplist/conda/language.txt +++ b/scripts/skiplist/conda/language.txt @@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128- test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256] test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256] test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256] -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16] @@ -278,7 +276,6 @@ test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16] test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32] test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32] test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8] -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] test/unit/language/test_core.py::test_bin_op[1-uint8-uint8-%] test/unit/language/test_core.py::test_bin_op[1-uint8-uint16-%] test/unit/language/test_core.py::test_bin_op[1-uint8-uint32-%] diff --git a/scripts/skiplist/default/language.txt b/scripts/skiplist/default/language.txt index a891b802b5..d408f0d8c3 100644 --- a/scripts/skiplist/default/language.txt +++ b/scripts/skiplist/default/language.txt @@ -1,4 +1,2 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] diff --git a/scripts/skiplist/lts/language.txt b/scripts/skiplist/lts/language.txt index c2842cdb91..a652a92cf5 100644 --- a/scripts/skiplist/lts/language.txt +++ b/scripts/skiplist/lts/language.txt @@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128- test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256] test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256] test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256] -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2662 test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 diff --git a/scripts/skiplist/mtl/language.txt b/scripts/skiplist/mtl/language.txt index 69530824f3..9f4888e59e 100644 --- a/scripts/skiplist/mtl/language.txt +++ b/scripts/skiplist/mtl/language.txt @@ -1,5 +1,3 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16] diff --git a/scripts/skiplist/xe2/language.txt b/scripts/skiplist/xe2/language.txt index a891b802b5..d408f0d8c3 100644 --- a/scripts/skiplist/xe2/language.txt +++ b/scripts/skiplist/xe2/language.txt @@ -1,4 +1,2 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 -test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] diff --git a/third_party/intel/language/intel/libdevice.py b/third_party/intel/language/intel/libdevice.py index e673909deb..29b9a77191 100644 --- a/third_party/intel/language/intel/libdevice.py +++ b/third_party/intel/language/intel/libdevice.py @@ -212,6 +212,34 @@ def rcp_ru(arg0, _builder=None): }, is_pure=True, _builder=_builder) +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__imf_sqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__imf_sqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__imf_sqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__imf_sqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + @core.extern def sqrt(arg0, _builder=None): return core.extern_elementwise( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp index 2cd885635d..110871cf0e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1420,6 +1420,34 @@ struct MulhiUIOpConversion const TargetInfoBase &targetInfo; }; +struct PreciseSqrtOpConversion + : ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(PreciseSqrtOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Value input = operands[0][0]; + Type origTy = input.getType(); + if (!origTy.isF64()) + input = fpext(f64_ty, input); + Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty}); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType); + LLVM::CallOp callOp = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input}); + callOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + Value result = callOp.getResult(); + if (!origTy.isF64()) + result = rewriter.create(loc, origTy, result); + return {result}; + } +}; + template struct OpToExternCallConversion : public ElementwiseOpConversionBase>( - typeConverter, axisInfoAnalysis, "__imf_sqrtf", benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); patterns.add>( typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);