Skip to content

Commit 5e7d368

Browse files
Disable type casting for pow() int arguments (#543)
1 parent 3731ce5 commit 5e7d368

File tree

3 files changed

+83
-10
lines changed

3 files changed

+83
-10
lines changed

pytato/array.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ def _binary_op(
582582
np.dtype[Any]] = _np_result_dtype,
583583
reverse: bool = False,
584584
cast_to_result_dtype: bool = True,
585+
is_pow: bool = False,
585586
) -> Array:
586587

587588
# {{{ sanity checks
@@ -601,14 +602,16 @@ def _binary_op(
601602
get_result_type,
602603
tags=tags,
603604
non_equality_tags=non_equality_tags,
604-
cast_to_result_dtype=cast_to_result_dtype)
605+
cast_to_result_dtype=cast_to_result_dtype,
606+
is_pow=is_pow)
605607
else:
606608
result = utils.broadcast_binary_op(
607609
self, other, op,
608610
get_result_type,
609611
tags=tags,
610612
non_equality_tags=non_equality_tags,
611-
cast_to_result_dtype=cast_to_result_dtype)
613+
cast_to_result_dtype=cast_to_result_dtype,
614+
is_pow=is_pow)
612615

613616
assert isinstance(result, Array)
614617
return result
@@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array:
648651
__rtruediv__ = partialmethod(_binary_op, operator.truediv,
649652
get_result_type=_truediv_result_type, reverse=True)
650653

651-
__pow__ = partialmethod(_binary_op, operator.pow)
652-
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True)
654+
__pow__ = partialmethod(_binary_op, operator.pow, is_pow=True)
655+
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True)
653656

654657
__neg__ = partialmethod(_unary_op, operator.neg)
655658

@@ -2403,7 +2406,8 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool:
24032406
lambda x, y: np.dtype(np.bool_),
24042407
tags=_get_default_tags(),
24052408
non_equality_tags=_get_created_at_tag(stacklevel=2),
2406-
cast_to_result_dtype=False
2409+
cast_to_result_dtype=False,
2410+
is_pow=False,
24072411
) # type: ignore[return-value]
24082412

24092413

@@ -2467,6 +2471,7 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
24672471
tags=_get_default_tags(),
24682472
non_equality_tags=_get_created_at_tag(),
24692473
cast_to_result_dtype=False,
2474+
is_pow=False,
24702475
) # type: ignore[return-value]
24712476

24722477

@@ -2484,6 +2489,7 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
24842489
tags=_get_default_tags(),
24852490
non_equality_tags=_get_created_at_tag(),
24862491
cast_to_result_dtype=False,
2492+
is_pow=False,
24872493
) # type: ignore[return-value]
24882494

24892495

pytato/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
195195
tags: frozenset[Tag],
196196
non_equality_tags: frozenset[Tag],
197197
cast_to_result_dtype: bool,
198+
is_pow: bool,
198199
) -> ArrayOrScalar:
199200
from pytato.array import _get_default_axes
200201

@@ -225,9 +226,19 @@ def cast_to_result_type(
225226
# Loopy's type casts don't like casting to bool
226227
assert result_dtype != np.bool_
227228

228-
expr = TypeCast(result_dtype, expr)
229+
# See https://github.com/inducer/pytato/issues/542
230+
# on why pow() + integers is not typecast to float or complex.
231+
if not (is_pow
232+
and np.issubdtype(array.dtype, np.integer)
233+
and not np.issubdtype(result_dtype, np.integer)):
234+
expr = TypeCast(result_dtype, expr)
229235
elif isinstance(expr, SCALAR_CLASSES):
230-
expr = result_dtype.type(expr)
236+
# See https://github.com/inducer/pytato/issues/542
237+
# on why pow() + integers is not typecast to float or complex.
238+
if not (is_pow
239+
and np.issubdtype(type(expr), np.integer)
240+
and not np.issubdtype(result_dtype, np.integer)):
241+
expr = result_dtype.type(expr)
231242

232243
return expr
233244

test/test_codegen.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,6 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse):
326326
"logical_and"))
327327
@pytest.mark.parametrize("reverse", (False, True))
328328
def test_array_array_binary_arith(ctx_factory, which, reverse):
329-
if which == "sub":
330-
pytest.skip("https://github.com/inducer/loopy/issues/131")
331-
332329
cl_ctx = ctx_factory()
333330
queue = cl.CommandQueue(cl_ctx)
334331
not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal",
@@ -2008,6 +2005,65 @@ def call_bar(tracer, x, y):
20082005
np.testing.assert_allclose(result_out[k], expect_out[k])
20092006

20102007

2008+
def test_pow_arg_casting(ctx_factory):
2009+
# Check that pow() arguments are not typecast from int
2010+
ctx = ctx_factory()
2011+
cq = cl.CommandQueue(ctx)
2012+
2013+
types = (int, np.int32, np.int64, float, np.float32, np.float64)
2014+
2015+
for base_scalar in (True, False):
2016+
for exponent_scalar in (True, False):
2017+
if base_scalar and exponent_scalar:
2018+
# Not supported in pytato
2019+
continue
2020+
2021+
for base_tp in types:
2022+
if base_scalar:
2023+
base_np = base_tp(2)
2024+
base = base_np
2025+
else:
2026+
base_np = np.array([1, 2, 3], base_tp)
2027+
base = pt.make_data_wrapper(base_np)
2028+
2029+
for exponent_tp in types:
2030+
if exponent_scalar:
2031+
exponent_np = exponent_tp(2)
2032+
exponent = exponent_np
2033+
else:
2034+
exponent_np = np.array([1, 2, 3], exponent_tp)
2035+
exponent = pt.make_data_wrapper(exponent_np)
2036+
2037+
out = base ** exponent
2038+
2039+
_, (pt_out,) = pt.generate_loopy(out)(cq)
2040+
2041+
np_out = np.power(base_np, exponent_np)
2042+
2043+
assert pt_out.dtype == np_out.dtype
2044+
np.testing.assert_allclose(np_out, pt_out)
2045+
2046+
if np.issubdtype(exponent_tp, np.integer):
2047+
assert exponent_tp in (int, np.int32, np.int64)
2048+
2049+
if exponent_scalar:
2050+
# We do cast between different int types
2051+
assert (type(out.expr.exponent) in
2052+
(int, np.int32, np.int64)
2053+
or out.expr.exponent.dtype == np_out.dtype)
2054+
else:
2055+
assert out.bindings["_in1"].dtype in \
2056+
(int, np.int32, np.int64)
2057+
else:
2058+
assert exponent_tp in (float, np.float32, np.float64)
2059+
if exponent_scalar:
2060+
assert type(out.expr.exponent) == np_out.dtype \
2061+
or out.expr.exponent.dtype == np_out.dtype
2062+
else:
2063+
assert out.bindings["_in1"].dtype in \
2064+
(float, np.float32, np.float64)
2065+
2066+
20112067
if __name__ == "__main__":
20122068
if len(sys.argv) > 1:
20132069
exec(sys.argv[1])

0 commit comments

Comments
 (0)