From ec594cba5378b5fab9a2aa3ae52807d568b95582 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 29 Jul 2024 21:41:21 +0000 Subject: [PATCH 1/4] Replaces _resolve_weak_types_comparisons with _resolve_weak_types_all_py_ints This new weak type resolver checks if the scalar is outside of the range of the strong data type and if so, returns the minimum scalar type for the value. --- dpctl/tensor/_elementwise_funcs.py | 15 ++++++++------- dpctl/tensor/_type_utils.py | 27 ++++++++++++++------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 84ca205a3c..0329c5cbe1 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -22,7 +22,7 @@ _acceptance_fn_negative, _acceptance_fn_reciprocal, _acceptance_fn_subtract, - _resolve_weak_types_comparisons, + _resolve_weak_types_all_py_ints, ) # U01: ==== ABS (x) @@ -661,6 +661,7 @@ _divide_docstring_, binary_inplace_fn=ti._divide_inplace, acceptance_fn=_acceptance_fn_divide, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _divide_docstring_ @@ -695,7 +696,7 @@ ti._equal_result_type, ti._equal, _equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _equal_docstring_ @@ -854,7 +855,7 @@ ti._greater_result_type, ti._greater, _greater_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _greater_docstring_ @@ -890,7 +891,7 @@ ti._greater_equal_result_type, ti._greater_equal, _greater_equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _greater_equal_docstring_ @@ -1041,7 +1042,7 @@ ti._less_result_type, ti._less, _less_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _less_docstring_ @@ -1077,7 +1078,7 @@ ti._less_equal_result_type, ti._less_equal, _less_equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _less_equal_docstring_ @@ -1552,7 +1553,7 @@ ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _not_equal_docstring_ diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 691f538336..262be2b03d 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): return o1_dtype, o2_dtype -def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): - "Resolves weak data type per NEP-0050 for comparisons," - "where result type is known to be `bool` and special behavior" - "is needed to handle mixed integer kinds" +def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev): + "Resolves weak data type per NEP-0050 for comparisons and" + " divide, where result type is known and special behavior" + "is needed to handle mixed integer kinds and Python integers" + "without overflow" if _is_weak_dtype(o1_dtype): if _is_weak_dtype(o2_dtype): raise ValueError @@ -415,10 +416,10 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): return _to_device_supported_dtype(dpt.float64, dev), o2_dtype else: if isinstance(o1_dtype, WeakIntegralType): - if o2_dtype.kind == "u": - # Python scalar may be negative, assumes mixed int loops - # exist - return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + o1_val = o1_dtype.get() + o2_iinfo = dpt.iinfo(o2_dtype) + if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max): + return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype return o2_dtype, o2_dtype elif _is_weak_dtype(o2_dtype): o1_kind_num = _strong_dtype_num_kind(o1_dtype) @@ -436,10 +437,10 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): ) else: if isinstance(o2_dtype, WeakIntegralType): - if o1_dtype.kind == "u": - # Python scalar may be negative, assumes mixed int loops - # exist - return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + o2_val = o2_dtype.get() + o1_iinfo = dpt.iinfo(o1_dtype) + if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max): + return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val)) return o1_dtype, o1_dtype else: return o1_dtype, o2_dtype @@ -834,7 +835,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q): "_acceptance_fn_negative", "_acceptance_fn_subtract", "_resolve_weak_types", - "_resolve_weak_types_comparisons", + "_resolve_weak_types_all_py_ints", "_weak_type_num_kind", "_strong_dtype_num_kind", "can_cast", From 0e807cb4fdcad712b53cb00aa29b85cc7963261c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 29 Jul 2024 23:42:52 +0000 Subject: [PATCH 2/4] _resolve_weak_types_all_py_ints only range-checks for integers --- dpctl/tensor/_type_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 262be2b03d..cc742b1b31 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -415,7 +415,9 @@ def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev): ) return _to_device_supported_dtype(dpt.float64, dev), o2_dtype else: - if isinstance(o1_dtype, WeakIntegralType): + if o1_kind_num == o2_kind_num and isinstance( + o1_dtype, WeakIntegralType + ): o1_val = o1_dtype.get() o2_iinfo = dpt.iinfo(o2_dtype) if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max): @@ -436,7 +438,9 @@ def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev): _to_device_supported_dtype(dpt.float64, dev), ) else: - if isinstance(o2_dtype, WeakIntegralType): + if o1_kind_num == o2_kind_num and isinstance( + o2_dtype, WeakIntegralType + ): o2_val = o2_dtype.get() o1_iinfo = dpt.iinfo(o1_dtype) if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max): From 9afca418ecc2713e7f18801e3a5f7d99369afe5a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 30 Jul 2024 23:10:09 +0000 Subject: [PATCH 3/4] Adds a test for `divide` behavior pointed out in gh-1711 --- dpctl/tests/elementwise/test_divide.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dpctl/tests/elementwise/test_divide.py b/dpctl/tests/elementwise/test_divide.py index 589f5237d1..610d0ccf31 100644 --- a/dpctl/tests/elementwise/test_divide.py +++ b/dpctl/tests/elementwise/test_divide.py @@ -256,3 +256,18 @@ def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype): else: with pytest.raises(ValueError): dpt.divide(ar1, ar2, out=ar2) + + +def test_divide_gh_1711(): + "See https://github.com/IntelPython/dpctl/issues/1711" + get_queue_or_skip() + + res = dpt.divide(-4, dpt.asarray(1, dtype="u4")) + assert isinstance(res, dpt.usm_ndarray) + assert res.dtype.kind == "f" + assert dpt.allclose(res, -4 / dpt.asarray(1, dtype="i4")) + + res = dpt.divide(dpt.asarray(3, dtype="u4"), -2) + assert isinstance(res, dpt.usm_ndarray) + assert res.dtype.kind == "f" + assert dpt.allclose(res, dpt.asarray(3, dtype="i4") / -2) From 655a5d995679ba2055d5316626b7c9dddeab1663 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 30 Jul 2024 23:19:53 +0000 Subject: [PATCH 4/4] Adds tests for fixed comparison behavior with very large Python integers --- dpctl/tests/elementwise/test_greater.py | 14 ++++++++++++++ dpctl/tests/elementwise/test_greater_equal.py | 14 ++++++++++++++ dpctl/tests/elementwise/test_less.py | 14 ++++++++++++++ dpctl/tests/elementwise/test_less_equal.py | 14 ++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/dpctl/tests/elementwise/test_greater.py b/dpctl/tests/elementwise/test_greater.py index d9fd852f18..248ea6bce4 100644 --- a/dpctl/tests/elementwise/test_greater.py +++ b/dpctl/tests/elementwise/test_greater.py @@ -281,3 +281,17 @@ def test_greater_mixed_integer_kinds(): # Python scalar assert dpt.all(dpt.greater(x2, -1)) assert not dpt.any(dpt.greater(-1, x2)) + + +def test_greater_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert py_int > x + assert not dpt.greater(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert x > -1 + assert not dpt.greater(-1, x) diff --git a/dpctl/tests/elementwise/test_greater_equal.py b/dpctl/tests/elementwise/test_greater_equal.py index 0f24aaa9b4..afe98f5026 100644 --- a/dpctl/tests/elementwise/test_greater_equal.py +++ b/dpctl/tests/elementwise/test_greater_equal.py @@ -280,3 +280,17 @@ def test_greater_equal_mixed_integer_kinds(): # Python scalar assert dpt.all(dpt.greater_equal(x2, -1)) assert not dpt.any(dpt.greater_equal(-1, x2)) + + +def test_greater_equal_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert py_int >= x + assert not dpt.greater_equal(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert x >= -1 + assert not dpt.greater_equal(-1, x) diff --git a/dpctl/tests/elementwise/test_less.py b/dpctl/tests/elementwise/test_less.py index b1cb497b04..6439e29e13 100644 --- a/dpctl/tests/elementwise/test_less.py +++ b/dpctl/tests/elementwise/test_less.py @@ -281,3 +281,17 @@ def test_less_mixed_integer_kinds(): # Python scalar assert not dpt.any(dpt.less(x2, -1)) assert dpt.all(dpt.less(-1, x2)) + + +def test_less_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert not py_int < x + assert dpt.less(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert not x < -1 + assert dpt.less(-1, x) diff --git a/dpctl/tests/elementwise/test_less_equal.py b/dpctl/tests/elementwise/test_less_equal.py index e189d94cdc..eca4a8fd68 100644 --- a/dpctl/tests/elementwise/test_less_equal.py +++ b/dpctl/tests/elementwise/test_less_equal.py @@ -280,3 +280,17 @@ def test_less_equal_mixed_integer_kinds(): # Python scalar assert not dpt.any(dpt.less_equal(x2, -1)) assert dpt.all(dpt.less_equal(-1, x2)) + + +def test_less_equal_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert not py_int <= x + assert dpt.less_equal(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert not x <= -1 + assert dpt.less_equal(-1, x)