From 5148e6a2fe135be9b212c2ccf5a537f8d04c2468 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 3 Mar 2026 09:32:37 -0800 Subject: [PATCH 1/3] Remove a stray attribute of ZerosLikeOp. --- pytato/raising.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytato/raising.py b/pytato/raising.py index 62174368b..3b35baf66 100644 --- a/pytato/raising.py +++ b/pytato/raising.py @@ -89,7 +89,6 @@ class BinaryOp(HighLevelOp): @dataclass(frozen=True, eq=True, repr=True) class ZerosLikeOp(HighLevelOp): - function: str x: Array From 71c47021811dda3d6a26f79075d1c1ee1accc00b Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 9 Feb 2026 10:19:39 -0800 Subject: [PATCH 2/3] Introduce pt.push_index_to_materialized_nodes. --- pytato/__init__.py | 4 + .../push_index_to_materialized_nodes.py | 1335 +++++++++++++++++ 2 files changed, 1339 insertions(+) create mode 100644 pytato/transform/push_index_to_materialized_nodes.py diff --git a/pytato/__init__.py b/pytato/__init__.py index d99432942..720065e9f 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -162,6 +162,9 @@ def set_debug_enabled(flag: bool) -> None: from pytato.transform.lower_to_index_lambda import to_index_lambda from pytato.transform.materialize import materialize_with_mpms from pytato.transform.metadata import unify_axes_tags +from pytato.transform.push_index_to_materialized_nodes import ( + push_index_to_materialized_nodes, +) from pytato.transform.remove_broadcasts_einsum import rewrite_einsums_with_no_broadcasts from pytato.visualization import ( get_dot_graph, @@ -264,6 +267,7 @@ def set_debug_enabled(flag: bool) -> None: "ones_like", "pad", "prod", + "push_index_to_materialized_nodes", "real", "reshape", "rewrite_einsums_with_no_broadcasts", diff --git a/pytato/transform/push_index_to_materialized_nodes.py b/pytato/transform/push_index_to_materialized_nodes.py new file mode 100644 index 000000000..6664391bd --- /dev/null +++ b/pytato/transform/push_index_to_materialized_nodes.py @@ -0,0 +1,1335 @@ +from __future__ import annotations + +from pytato.scalar_expr import INT_CLASSES + + +__copyright__ = """Copyright (C) 2026 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from abc import ABC +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Literal, cast + +import numpy as np +from constantdict import constantdict +from typing_extensions import override + +from pytato.array import ( + Array, + Axis, + DataWrapper, + DictOfNamedArrays, + IndexBase, + IndexExpr, + IndexLambda, + InputArgumentBase, + NormalizedSlice, + Placeholder, + ShapeComponent, + ShapeType, + expand_dims, + transpose, + zeros, +) +from pytato.loopy import LoopyCall +from pytato.raising import ( + BinaryOp, + BinaryOpType, + BroadcastOp, + C99CallOp, + LogicalNotOp, + ReduceOp, + WhereOp, + ZerosLikeOp, + index_lambda_to_high_level_op, +) +from pytato.tags import ImplStored +from pytato.transform import ArrayOrNames, CacheKeyT, TransformMapperWithExtraArgs +from pytato.utils import are_shape_components_equal, get_shape_after_broadcasting + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import TypeAlias + + from pymbolic import Scalar + from pymbolic.typing import Integer + + from pytato.array import ( + ArrayOrScalar, + AxisPermutation, + Concatenate, + Einsum, + IndexBase, + Reshape, + Roll, + SizeParam, + Stack, + ) + from pytato.function import NamedCallResult + from pytato.loopy import LoopyCallResult + from pytato.transform import ArrayOrNamesTc + +IndexesT: TypeAlias = tuple[IndexExpr, ...] + + +def _lower_binary_op_hlo(hlo: BinaryOp) -> Array: + """ + Returns a :class:`pytato.Array` corresponding to a binary operation + high-level op. + """ + from pytato.array import ( + equal, + greater, + greater_equal, + less, + less_equal, + logical_and, + logical_or, + not_equal, + ) + + assert isinstance(hlo.x1, Array) or isinstance(hlo.x2, Array) + # Note: We have a bunch of "pyright reportOperatorIssue" below, it does not + # respect the above runtime guard that at least one of x1 and x2 are of type + # Array. + + match hlo.binary_op: + case BinaryOpType.ADD: + return cast("Array", hlo.x1 + hlo.x2) + case BinaryOpType.SUB: + return cast( + "Array", hlo.x1 - hlo.x2 # pyright: ignore[reportOperatorIssue] + ) + case BinaryOpType.MULT: + return cast("Array", hlo.x1 * hlo.x2) + case BinaryOpType.LOGICAL_OR: + return cast("Array", logical_or(hlo.x1, hlo.x2)) + case BinaryOpType.LOGICAL_AND: + return cast("Array", logical_and(hlo.x1, hlo.x2)) + case BinaryOpType.BITWISE_OR: + return cast( + "Array", hlo.x1 | hlo.x2 # pyright: ignore[reportOperatorIssue] + ) + case BinaryOpType.BITWISE_AND: + return cast( + "Array", hlo.x1 & hlo.x2 # pyright: ignore[reportOperatorIssue] + ) + case BinaryOpType.BITWISE_XOR: + return cast( + "Array", hlo.x1 ^ hlo.x2 # pyright: ignore[reportOperatorIssue] + ) + case BinaryOpType.TRUEDIV: + return cast("Array", hlo.x1 / hlo.x2) + case BinaryOpType.FLOORDIV: + return cast( + "Array", hlo.x1 // hlo.x2 # pyright: ignore[reportOperatorIssue] + ) + case BinaryOpType.POWER: + return cast("Array", hlo.x1**hlo.x2) + case BinaryOpType.MOD: + return cast( + "Array", hlo.x1 % hlo.x2 # pyright: ignore[reportOperatorIssue] + ) + case BinaryOpType.LESS: + return cast("Array", less(hlo.x1, hlo.x2)) + case BinaryOpType.LESS_EQUAL: + return cast("Array", less_equal(hlo.x1, hlo.x2)) + case BinaryOpType.GREATER: + return cast("Array", greater(hlo.x1, hlo.x2)) + case BinaryOpType.GREATER_EQUAL: + return cast("Array", greater_equal(hlo.x1, hlo.x2)) + case BinaryOpType.EQUAL: + return cast("Array", equal(hlo.x1, hlo.x2)) + case BinaryOpType.NOT_EQUAL: + return cast("Array", not_equal(hlo.x1, hlo.x2)) + + +def _lower_call_op_hlo(hlo: C99CallOp) -> Array: + """ + Returns a :class:`pytato.Array` corresponding to a function high level op. + """ + + import pytato.cmath + from pytato.raising import PT_C99BINARY_FUNCS, PT_C99UNARY_FUNCS + + function = hlo.function + if function in {"asin", "acos", "atan", "atan2"}: + # these functions have different names on the numpy side vs on the C99 + # side. + function = "arc" + function[1:] + if hlo.function in PT_C99UNARY_FUNCS: + unary_mathfn = cast( + "Callable[[ArrayOrScalar], ArrayOrScalar]", + getattr(pytato.cmath, function), + ) + return cast("Array", unary_mathfn(hlo.args[0])) + else: + assert hlo.function in PT_C99BINARY_FUNCS + binary_mathfn = cast( + "Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar]", + getattr(pytato.cmath, function), + ) + return cast("Array", binary_mathfn(hlo.args[0], hlo.args[1])) + + +def _is_trivial_slice(dim: ShapeComponent, slice_: IndexExpr) -> bool: + """ + Returns *True* only if *slice_* represents the ``[:]`` index for an array's + axis of length *dim*. + """ + return ( + isinstance(slice_, NormalizedSlice) + and slice_.start == 0 + and slice_.step == 1 + and slice_.stop == dim + ) + + +def _is_materialized(x: Array) -> bool: + # TODO: Maybe in the later versions, think about LoopyCallResult, etc. + return isinstance(x, InputArgumentBase) or len(x.tags_of_type(ImplStored)) != 0 + + +def get_indexing_kind( + indices: Sequence[IndexExpr], +) -> Literal["basic", "contiguous_advanced", "non_contiguous_advanced"]: + """ + Returns what kind of :mod:`numpy` indexing does *indices* correspond to. + """ + from pytato.utils import partition + + i_adv_indices, i_basic_indices = partition( + lambda idx: isinstance(indices[idx], NormalizedSlice), range(len(indices)) + ) + if all( + isinstance(idx := indices[i_adv_idx], INT_CLASSES) + or (isinstance(idx, Array) and idx.ndim == 0) + for i_adv_idx in i_adv_indices + ): + return "basic" + elif any( + i_adv_indices[0] < i_basic_idx < i_adv_indices[-1] + for i_basic_idx in i_basic_indices + ): + return "non_contiguous_advanced" + else: + return "contiguous_advanced" + + +def _partition_into_adv_and_basic_indices( + indices: Sequence[IndexExpr], +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Returns the tuple ``(ia, ib)``, such that, for every ``i`` in ``ib``, + ``indices[i]`` is an instance of :class:`pytato.NormalizedSlice`, and, for + every ``j`` in ``ia``, ``indices[j]`` is an instance of + :class:`pytato.Array` or an integer. + """ + from pytato.utils import partition + + i_adv_indices, i_basic_indices = partition( + lambda idx: isinstance(indices[idx], NormalizedSlice), range(len(indices)) + ) + return tuple(i_adv_indices), tuple(i_basic_indices) + + +def _get_indices_shape(indices: Sequence[IndexExpr]) -> ShapeType: + """ + Returns the shape of the array constructed by the :mod:`numpy` styled + indexing: ``x[indices[0], indices[1], ...]``. + """ + from pytato.utils import _normalized_slice_len + + kind = get_indexing_kind(indices) + i_adv_indices, i_basic_indices = _partition_into_adv_and_basic_indices(indices) + if kind == "basic": + return tuple( + _normalized_slice_len(cast("NormalizedSlice", indices[i_idx])) + for i_idx in i_basic_indices + ) + elif kind == "contiguous_advanced": + return ( + tuple( + _normalized_slice_len(cast("NormalizedSlice", indices[i_idx])) + for i_idx in i_basic_indices + if i_idx < i_adv_indices[0] + ) + + get_shape_after_broadcasting( + [cast("Array | Scalar", indices[i_idx]) for i_idx in i_adv_indices] + ) + + tuple( + _normalized_slice_len(cast("NormalizedSlice", indices[i_idx])) + for i_idx in i_basic_indices + if i_idx > i_adv_indices[-1] + ) + ) + else: + assert kind == "non_contiguous_advanced" + return get_shape_after_broadcasting( + [cast("Array | Scalar", indices[i_idx]) for i_idx in i_adv_indices] + ) + tuple( + _normalized_slice_len(cast("NormalizedSlice", indices[i_idx])) + for i_idx in i_basic_indices + ) + + +def _get_indices_ndim(indices: Sequence[IndexExpr]) -> int: + """ + Returns the dimensionality of the array ``x[indices[0], indices[1], ...]``. + """ + return len([idx for idx in indices if isinstance(idx, NormalizedSlice)]) + len( + get_shape_after_broadcasting([idx for idx in indices if isinstance(idx, Array)]) + ) + + +@dataclass(frozen=True) +class AxisAccess(ABC): # noqa: B024 + """ + Records an index access expression along an array's axis. + """ + + +@dataclass(frozen=True) +class PointAccess(AxisAccess): + """ + Represents a single point access into an array's access. + """ + + point: Integer + + +@dataclass(frozen=True) +class SliceAccess(AxisAccess): + """ + Records a slice access of an axis that targets the axis :attr:`tgt_axis` in + the output. + + Consider X an array of shape ``(10, 10, 10, 10)`` which is indexed with + non-contiguous advanced indices as ``Y[_0, _1] = X[idx1[_0], 3, _1, idx2[_0]]``. In + this expression, the access to the 3rd axis from left is modeled as -- + ``SliceAccess(tgt_axis=1, slice_=NormalizedSlice(0, 10, 1))``. + """ + + tgt_axis: int + slice_: NormalizedSlice + + def __post_init__(self) -> None: + assert isinstance(self.tgt_axis, int) + assert isinstance(self.slice_, NormalizedSlice) + + +@dataclass(frozen=True) +class ArrayIndexAccess(AxisAccess): + """ + Records an array access of an axis that targets the axes :attr:`tgt_axes` in + the output. + + Consider X an array of shape ``(10, 10, 10, 10)`` which is indexed with + non-contiguous advanced indices as ``Y[_0, _1] = X[idx1[_0], 3, _1, idx2[_0, + _1]]``. In this expression, the access to the 4th axis from left is modeled + as -- ``ArrayIndexAccess(tgt_axes=(0, 1), ary=idx2)``. + """ + + tgt_axes: tuple[int, ...] + ary: Array + + def __post_init__(self) -> None: + assert isinstance(self.ary, Array) + assert self.ary.ndim == len(self.tgt_axes) + + +def _get_axis_accesses(indices: IndexesT) -> tuple[AxisAccess, ...]: + accesses: list[AxisAccess] = [] + kind = get_indexing_kind(indices) + i_adv_indices, _ = _partition_into_adv_and_basic_indices(indices) + adv_ndim = _get_indices_ndim([indices[i_idx] for i_idx in i_adv_indices]) + + if kind == "basic": + tgt_axis = 0 + for idx in indices: + if isinstance(idx, INT_CLASSES): + accesses.append(PointAccess(idx)) + elif isinstance(idx, Array): + assert idx.ndim == 0 + accesses.append(ArrayIndexAccess((), idx)) + else: + assert isinstance(idx, NormalizedSlice) + accesses.append(SliceAccess(tgt_axis, idx)) + tgt_axis += 1 + elif kind == "contiguous_advanced": + tgt_axis = 0 + + for i_idx in range(i_adv_indices[0]): + idx = indices[i_idx] + assert isinstance(idx, NormalizedSlice) + accesses.append(SliceAccess(tgt_axis, idx)) + tgt_axis += 1 + + for i_idx in i_adv_indices: + idx = indices[i_idx] + if isinstance(idx, INT_CLASSES): + accesses.append(PointAccess(idx)) + else: + assert isinstance(idx, Array) + accesses.append( + ArrayIndexAccess( + tuple( + range(tgt_axis + adv_ndim - idx.ndim, tgt_axis + adv_ndim) + ), + idx, + ) + ) + tgt_axis += adv_ndim + for i_idx in range(i_adv_indices[-1] + 1, len(indices)): + idx = indices[i_idx] + assert isinstance(idx, NormalizedSlice) + accesses.append(SliceAccess(tgt_axis, idx)) + tgt_axis += 1 + else: + assert kind == "non_contiguous_advanced" + slice_tgt_axis = adv_ndim + for idx in indices: + if isinstance(idx, INT_CLASSES): + accesses.append(PointAccess(idx)) + elif isinstance(idx, NormalizedSlice): + accesses.append(SliceAccess(slice_tgt_axis, idx)) + slice_tgt_axis += 1 + else: + assert isinstance(idx, Array) + accesses.append( + ArrayIndexAccess(tuple(range(adv_ndim - idx.ndim, adv_ndim)), idx) + ) + + assert len(accesses) == len(indices) + return tuple(accesses) + + +def _get_array_tgt_axes(accesses: tuple[AxisAccess, ...]) -> tuple[int, ...]: + """ + Returns all the target axes that contributions from an indirection access in + *accesses*. + """ + from functools import reduce + + return tuple( + sorted( + reduce( + lambda x1, x2: x1 | x2, + ( + frozenset(access.tgt_axes) + for access in accesses + if isinstance(access, ArrayIndexAccess) + ), + cast("frozenset[int]", frozenset()), + ) + ) + ) + + +def _permute_tgt_axes( + accesses: tuple[AxisAccess, ...], perm: tuple[int, ...] +) -> tuple[AxisAccess, ...]: + """ + Returns a transformed version of *accesses* such that the target axes are + permuted with the permutation *perm*. + """ + new_accesses: list[AxisAccess] = [] + for access in accesses: + if isinstance(access, PointAccess): + new_accesses.append(access) + elif isinstance(access, SliceAccess): + new_accesses.append(SliceAccess(perm[access.tgt_axis], access.slice_)) + else: + assert isinstance(access, ArrayIndexAccess) + new_tgt_axes = tuple(perm[tgt_axis] for tgt_axis in access.tgt_axes) + new_accesses.append(ArrayIndexAccess(new_tgt_axes, access.ary)) + + assert len(new_accesses) == len(accesses) + return tuple(new_accesses) + + +def _get_resulting_target_axes(accesses: Sequence[AxisAccess]) -> tuple[int, ...]: + from pytato.utils import partition + + i_adv_indices, i_basic_indices = partition( + lambda idx: isinstance(accesses[idx], SliceAccess), range(len(accesses)) + ) + + if all( + isinstance(accesses[i_adv_idx], PointAccess) + or ( + isinstance(accesses[i_adv_idx], ArrayIndexAccess) + and cast("ArrayIndexAccess", accesses[i_adv_idx]).ary.ndim == 0 + ) + for i_adv_idx in i_adv_indices + ): + kind = "basic" + elif any( + i_adv_indices[0] < i_basic_idx < i_adv_indices[-1] + for i_basic_idx in i_basic_indices + ): + kind = "non_contiguous_advanced" + else: + kind = "contiguous_advanced" + + if kind == "basic": + return tuple( + access.tgt_axis for access in accesses if isinstance(access, SliceAccess) + ) + elif kind == "contiguous_advanced": + pre_basic_tgts = tuple( + access.tgt_axis + for access in accesses[: i_adv_indices[0]] + if isinstance(access, SliceAccess) + ) + advanced_tgts = max( + [ + cast("ArrayIndexAccess", accesses[i_adv_idx]) + for i_adv_idx in i_adv_indices + if isinstance(accesses[i_adv_idx], ArrayIndexAccess) + ], + default=ArrayIndexAccess((), zeros(())), + key=lambda x: x.ary.ndim, + ).tgt_axes + + assert all( + cast("ArrayIndexAccess", accesses[i_adv_idx]).tgt_axes + == advanced_tgts[-cast("ArrayIndexAccess", accesses[i_adv_idx]).ary.ndim :] + for i_adv_idx in i_adv_indices + if isinstance(accesses[i_adv_idx], ArrayIndexAccess) + ) + post_basic_tgts = tuple( + access.tgt_axis + for access in accesses[i_adv_indices[-1] + 1 :] + if isinstance(access, SliceAccess) + ) + return pre_basic_tgts + advanced_tgts + post_basic_tgts + else: + assert kind == "non_contiguous_advanced" + basic_tgts = tuple( + cast("SliceAccess", accesses[i_idx]).tgt_axis for i_idx in i_basic_indices + ) + advanced_tgts = max( + [ + cast("ArrayIndexAccess", accesses[i_adv_idx]) + for i_adv_idx in i_adv_indices + if isinstance(accesses[i_adv_idx], ArrayIndexAccess) + ], + default=ArrayIndexAccess((), zeros(())), + key=lambda x: x.ary.ndim, + ).tgt_axes + + assert all( + cast("ArrayIndexAccess", accesses[i_adv_idx]).tgt_axes + == advanced_tgts[-cast("ArrayIndexAccess", accesses[i_adv_idx]).ary.ndim :] + for i_adv_idx in i_adv_indices + if isinstance(accesses[i_adv_idx], ArrayIndexAccess) + ) + return advanced_tgts + basic_tgts + + +def _compose_axis_transposes( + inner_perm: tuple[int, ...], outer_perm: tuple[int, ...] +) -> tuple[int, ...]: + """ + Returns ``axis_perm`` such that ``pt.transpose(x, axis_perm) == + pt.transpose(pt.transpose(x, inner_perm), outer_perm)``. + """ + n = len(inner_perm) + assert n == len(outer_perm) + return tuple(inner_perm[outer_perm[i]] for i in range(n)) + + +def _compose_indices( + *, inner_indices: IndexesT, outer_indices: IndexesT +) -> tuple[tuple[int, ...], IndexesT]: + """ + Returns ``(axis_perm, indices)`` such that ``pt.transpose(x[indices], + axis_perm) == x[inner_indices][outer_indices]``. + """ + accesses_to_x = _get_axis_accesses(inner_indices) + accesses_to_x_inner = _get_axis_accesses(outer_indices) + adv_idx_shape_of_inner_indices = get_shape_after_broadcasting( + [idx for idx in inner_indices if isinstance(idx, Array)] + ) + adv_tgt_axes_of_x = _get_array_tgt_axes(accesses_to_x) + adv_tgt_axes_of_x_inner = _get_array_tgt_axes(accesses_to_x_inner) + + # {{{ identify additional advanced target axes. + + # These additional indirections in the composed indices force an axis + # permutation on the output. + + additional_adv_tgt_axes = tuple( + access_to_x_inner.tgt_axis + for iaxis, access_to_x_inner in enumerate(accesses_to_x_inner) + if (isinstance(access_to_x_inner, SliceAccess) and iaxis in adv_tgt_axes_of_x) + ) + + if not additional_adv_tgt_axes: + axis_perm_inv = tuple(range(_get_indices_ndim(outer_indices))) + else: + if get_indexing_kind(outer_indices) in {"contiguous_advanced", "basic"}: + first_adv_tgt_axis_of_x_inner = ( + adv_tgt_axes_of_x_inner[0] + if adv_tgt_axes_of_x_inner + else additional_adv_tgt_axes[0] + ) + last_adv_tgt_axis_of_x_inner = ( + adv_tgt_axes_of_x_inner[-1] + if adv_tgt_axes_of_x_inner + else additional_adv_tgt_axes[0] + ) + + additional_adv_tgt_axes_before = tuple( + tgt_axis + for tgt_axis in additional_adv_tgt_axes + if tgt_axis < first_adv_tgt_axis_of_x_inner + ) + additional_adv_tgt_axes_after = tuple( + tgt_axis + for tgt_axis in additional_adv_tgt_axes + if tgt_axis > last_adv_tgt_axis_of_x_inner + ) + axis_perm_inv = ( + tuple( + iaxis + for iaxis in range(first_adv_tgt_axis_of_x_inner) + if iaxis not in additional_adv_tgt_axes + ) + + (() if adv_tgt_axes_of_x_inner else (first_adv_tgt_axis_of_x_inner,)) + + additional_adv_tgt_axes_before + + adv_tgt_axes_of_x_inner + + additional_adv_tgt_axes_after + + tuple( + iaxis + for iaxis in range( + last_adv_tgt_axis_of_x_inner + 1, + _get_indices_ndim(outer_indices), + ) + if iaxis not in additional_adv_tgt_axes + ) + ) + else: + assert get_indexing_kind(outer_indices) == "non_contiguous_advanced" + axis_perm_inv = ( + adv_tgt_axes_of_x_inner + + additional_adv_tgt_axes + + tuple( + i + for i in range( + len(adv_tgt_axes_of_x_inner), _get_indices_ndim(outer_indices) + ) + if i not in additional_adv_tgt_axes + ) + ) + + # }}} + + axis_perm = tuple(cast("list[int]", np.argsort(axis_perm_inv).tolist())) + accesses_to_x_inner = _permute_tgt_axes( + accesses_to_x_inner, tuple(cast("list[int]", np.argsort(axis_perm).tolist())) + ) + adv_tgt_axes_of_x_inner = tuple( + sorted( + axis_perm_inv[iaxis] + for iaxis in (adv_tgt_axes_of_x_inner + additional_adv_tgt_axes) + ) + ) + del axis_perm_inv + + composed_indices_to_x: list[IndexExpr] = [] + composed_accesses_to_x: list[AxisAccess] = [] + + for access_to_x in accesses_to_x: + if isinstance(access_to_x, PointAccess): + composed_indices_to_x.append(access_to_x.point) + composed_accesses_to_x.append(access_to_x) + elif isinstance(access_to_x, SliceAccess): + inner_slice = access_to_x.slice_ + access_to_x_inner = accesses_to_x_inner[access_to_x.tgt_axis] + if isinstance(access_to_x_inner, PointAccess): + composed_pt = ( + inner_slice.start + inner_slice.step * access_to_x_inner.point + ) + composed_indices_to_x.append(composed_pt) + if isinstance(composed_pt, Array): + composed_accesses_to_x.append(ArrayIndexAccess((), composed_pt)) + else: + composed_accesses_to_x.append(PointAccess(composed_pt)) + elif isinstance(access_to_x_inner, SliceAccess): + outer_slice = access_to_x_inner.slice_ + composed_slice = NormalizedSlice( + inner_slice.start + inner_slice.step * outer_slice.start, + inner_slice.start + inner_slice.step * outer_slice.stop, + inner_slice.step * outer_slice.step, + ) + composed_indices_to_x.append(composed_slice) + composed_accesses_to_x.append( + SliceAccess(access_to_x_inner.tgt_axis, composed_slice) + ) + else: + assert isinstance(access_to_x_inner, ArrayIndexAccess) + idx_ary = access_to_x_inner.ary + if inner_slice.step != 1: + idx_ary = inner_slice.step * idx_ary + if not are_shape_components_equal(inner_slice.start, 0): + idx_ary = idx_ary + inner_slice.start + dims_to_expand = tuple( + iaxis + for iaxis, tgt_axis in enumerate( + adv_tgt_axes_of_x_inner[ + adv_tgt_axes_of_x_inner.index( + min(access_to_x_inner.tgt_axes) + ) : + ] + ) + if tgt_axis not in access_to_x_inner.tgt_axes + ) + if dims_to_expand: + idx_ary = expand_dims(idx_ary, dims_to_expand) + + composed_indices_to_x.append(idx_ary) + composed_accesses_to_x.append( + ArrayIndexAccess(adv_tgt_axes_of_x_inner[-idx_ary.ndim :], idx_ary) + ) + else: + assert isinstance(access_to_x, ArrayIndexAccess) + tgt_axes = access_to_x.tgt_axes + resulting_tgt_axes = _get_resulting_target_axes( + [ + ( + PointAccess(0) + if isinstance(accesses_to_x_inner[tgt_axis], ArrayIndexAccess) + and not are_shape_components_equal( + adv_idx_shape_of_inner_indices[ + -access_to_x.ary.ndim + src_axis + ], + access_to_x.ary.shape[src_axis], + ) + else accesses_to_x_inner[tgt_axis] + ) + for src_axis, tgt_axis in enumerate(tgt_axes) + ] + ) + new_indices: list[Integer | slice | Array] = [] + for src_axis, tgt_axis in enumerate(tgt_axes): + access_to_x_inner = accesses_to_x_inner[tgt_axis] + if isinstance(access_to_x_inner, PointAccess): + new_indices.append(access_to_x_inner.point) + elif isinstance(access_to_x_inner, SliceAccess): + slice_ = access_to_x_inner.slice_ + if not are_shape_components_equal( + adv_idx_shape_of_inner_indices[ + -access_to_x.ary.ndim + src_axis + ], + access_to_x.ary.shape[src_axis], + ): + assert are_shape_components_equal( + access_to_x.ary.shape[src_axis], 1 + ) + new_indices.append(slice(0, 1, 1)) + else: + new_indices.append( + slice(slice_.start, slice_.stop, slice_.step) + ) + else: + assert isinstance(access_to_x_inner, ArrayIndexAccess) + if not are_shape_components_equal( + adv_idx_shape_of_inner_indices[ + -access_to_x.ary.ndim + src_axis + ], + access_to_x.ary.shape[src_axis], + ): + assert are_shape_components_equal( + access_to_x.ary.shape[src_axis], 1 + ) + new_indices.append(0) + else: + new_indices.append(access_to_x_inner.ary) + + axis_perm_inner = tuple( + cast("list[int]", np.argsort(resulting_tgt_axes).tolist()) + ) + dims_to_expand = ( + tuple( + iaxis + for iaxis, tgt_axis in enumerate( + adv_tgt_axes_of_x_inner[ + adv_tgt_axes_of_x_inner.index(min(resulting_tgt_axes)) : + ] + ) + if tgt_axis not in resulting_tgt_axes + ) + if adv_tgt_axes_of_x_inner + else () + ) + + new_ary = access_to_x.ary + if not all( + isinstance(idx, slice) + and _is_trivial_slice( + axis_len, + NormalizedSlice( + cast("ShapeComponent", idx.start), + cast("ShapeComponent", idx.stop), + cast("int", idx.step), + ), + ) + for idx, axis_len in zip( + new_indices, access_to_x.ary.shape, strict=True + ) + ): + new_ary = new_ary[tuple(new_indices)] + + if axis_perm_inner != tuple(range(new_ary.ndim)): + new_ary = transpose(new_ary, axis_perm_inner) + + if dims_to_expand: + new_ary = expand_dims(new_ary, dims_to_expand) + composed_indices_to_x.append(new_ary) + composed_accesses_to_x.append( + ArrayIndexAccess( + adv_tgt_axes_of_x_inner[-new_ary.ndim :], + new_ary, + ) + ) + + additional_axis_perm = tuple( + cast( + "list[int]", + np.argsort(_get_resulting_target_axes(composed_accesses_to_x)).tolist(), + ) + ) + + assert len(composed_indices_to_x) == len(inner_indices) + return _compose_axis_transposes(additional_axis_perm, axis_perm), tuple( + composed_indices_to_x + ) + + +def _get_indices_for_broadcast( + from_shape: ShapeType, to_shape: ShapeType, indices: IndexesT +) -> tuple[tuple[int, ...], tuple[int, ...], IndexesT]: + """ + Returns ``(axis_perm, dims_to_expand, new_indices)`` such that + ``pt.broadcast_to(x, to_shape)[indices] == + pt.transpose(pt.expand_dims(pt.broadcast_to(x[new_idxs], + _get_indexing_shape(indices)), dims_to_expand), axis_perm)``, where + ``x.shape == from_shape``. + """ + assert len(to_shape) == len(indices) + new_indices: list[IndexExpr] = [] + + for from_axis_len, to_axis_len, idx in zip( + from_shape, + to_shape[-len(from_shape) :], + indices[-len(from_shape) :], + strict=True, + ): + if are_shape_components_equal(from_axis_len, to_axis_len): + new_indices.append(idx) + else: + assert are_shape_components_equal(from_axis_len, 1) + if isinstance(idx, NormalizedSlice): + new_indices.append(NormalizedSlice(0, 1, 1)) + else: + new_indices.append(0) + + i_adv_axes, i_basic_axes = _partition_into_adv_and_basic_indices(indices) + i_new_adv_axes, i_new_basic_axes = _partition_into_adv_and_basic_indices( + new_indices + ) + adv_ndim = _get_indices_ndim([indices[iaxis] for iaxis in i_adv_axes]) + new_adv_ndim = _get_indices_ndim([new_indices[iaxis] for iaxis in i_new_adv_axes]) + + if get_indexing_kind(indices) == "basic": + assert get_indexing_kind(new_indices) == "basic" + dims_to_expand: tuple[int, ...] = () + axis_perm = tuple(range(_get_indices_ndim(new_indices))) + elif get_indexing_kind(indices) == "contiguous_advanced": + if get_indexing_kind(new_indices) == "basic": + n_new_pre_basic_axes = len( + [iaxis for iaxis in i_new_basic_axes if iaxis < i_new_adv_axes[0]] + ) + dims_to_expand = ( + tuple( + range( + n_new_pre_basic_axes, + n_new_pre_basic_axes + adv_ndim, + ) + ) + if n_new_pre_basic_axes + else () + ) + axis_perm = tuple(range(_get_indices_ndim(new_indices))) + else: + assert get_indexing_kind(new_indices) == "contiguous_advanced" + assert len( + [iaxis for iaxis in i_basic_axes if iaxis > i_adv_axes[-1]] + ) == len( + [iaxis for iaxis in i_new_basic_axes if iaxis > i_new_adv_axes[-1]] + ) + n_new_pre_basic_axes = len( + [iaxis for iaxis in i_new_basic_axes if iaxis < i_new_adv_axes[0]] + ) + + dims_to_expand = tuple( + range( + n_new_pre_basic_axes, n_new_pre_basic_axes + adv_ndim - new_adv_ndim + ) + ) + axis_perm = tuple(range(_get_indices_ndim(new_indices))) + else: + assert get_indexing_kind(indices) == "non_contiguous_advanced" + if get_indexing_kind(new_indices) == "basic": + dims_to_expand = () + axis_perm = tuple(range(_get_indices_ndim(new_indices))) + elif get_indexing_kind(new_indices) == "contiguous_advanced": + n_missing_basic_axes = len(i_basic_axes) - len(i_new_basic_axes) + dims_to_expand = tuple( + range( + _get_indices_ndim(new_indices), + _get_indices_ndim(new_indices) + n_missing_basic_axes, + ) + ) + n_new_pre_basic_axes = len( + [iaxis for iaxis in i_new_basic_axes if iaxis < i_new_adv_axes[0]] + ) + n_new_post_basic_axes = len( + [iaxis for iaxis in i_new_basic_axes if iaxis > i_new_adv_axes[-1]] + ) + axis_perm = ( + tuple(range(n_new_pre_basic_axes, n_new_pre_basic_axes + new_adv_ndim)) + + tuple( + range( + n_new_pre_basic_axes + new_adv_ndim + n_new_post_basic_axes, + n_missing_basic_axes + + n_new_pre_basic_axes + + new_adv_ndim + + n_new_post_basic_axes, + ) + ) + + tuple(range(n_new_pre_basic_axes)) + + tuple( + range( + n_new_pre_basic_axes + new_adv_ndim, + n_new_pre_basic_axes + new_adv_ndim + n_new_post_basic_axes, + ) + ) + ) + else: + assert get_indexing_kind(new_indices) == "non_contiguous_advanced" + n_missing_basic_axes = len(i_basic_axes) - len(i_new_basic_axes) + dims_to_expand = tuple( + range(new_adv_ndim, new_adv_ndim + n_missing_basic_axes) + ) + axis_perm = tuple( + range(_get_indices_ndim(new_indices) + n_missing_basic_axes) + ) + + return axis_perm, dims_to_expand, tuple(new_indices) + + +class IndexPusher(TransformMapperWithExtraArgs[[IndexesT]]): + @override + def get_cache_key(self, expr: ArrayOrNames, indices: IndexesT) -> CacheKeyT: + return (expr, indices) + + def rec_ary(self, expr: Array, indices: IndexesT) -> Array: + result = self.rec(expr, indices) + assert isinstance(result, Array) + return result + + def rec_w_passthru_indices(self, expr: Array) -> Array: + """ + Recurse over *expr* with all indices being the trivial slices, i.e. + ``slice()``. + """ + indices = tuple(NormalizedSlice(0, axis_len, 1) for axis_len in expr.shape) + return self.rec_ary(expr, indices) + + def rec_w_broadcast( + self, expr: Array, to_shape: ShapeType, indices: IndexesT + ) -> Array: + axis_perm, new_dims, new_indices = _get_indices_for_broadcast( + expr.shape, to_shape, indices + ) + expr = self.rec_ary(expr, new_indices) + expr = expand_dims(expr, new_dims) if new_dims else expr + return ( + transpose(expr, axis_perm) if axis_perm != tuple(range(expr.ndim)) else expr + ) + + def _eagerly_index(self, expr: Array, indices: IndexesT) -> Array: + """ + Returns *expr* index with *indices*, i.e. returns ``expr[indices]``. + + .. note:: + + No further attempts at propagating + *indices* to the predecessors of *expr* is done. + """ + assert expr.ndim == len(indices) + if all( + _is_trivial_slice(dim, idx) + for dim, idx in zip(expr.shape, indices, strict=True) + ): + return expr + else: + # explains why this + # is needed. + new_indices = tuple( + ( + slice(idx.start, idx.stop, idx.step) + if isinstance(idx, NormalizedSlice) + else idx + ) + for idx in indices + ) + return expr[new_indices] + + def map_placeholder(self, expr: Placeholder, indices: IndexesT) -> Array: + return self._eagerly_index(expr, indices) + + def map_data_wrapper(self, expr: DataWrapper, indices: IndexesT) -> Array: + return self._eagerly_index(expr, indices) + + def _map_index_base(self, expr: IndexBase, indices: IndexesT) -> Array: + assert len(indices) == expr.ndim + if _is_materialized(expr): + return self._eagerly_index(self.rec_ary(expr.array, expr.indices), indices) + else: + axis_perm, new_indices = _compose_indices( + outer_indices=indices, inner_indices=expr.indices + ) + reced_ary = self.rec_ary(expr.array, new_indices) + if axis_perm == tuple(range(len(axis_perm))): + return reced_ary + else: + return transpose(reced_ary, axis_perm) + + def map_basic_index(self, expr: IndexBase, indices: IndexesT) -> Array: + return self._map_index_base(expr, indices) + + def map_contiguous_advanced_index( + self, expr: IndexBase, indices: IndexesT + ) -> Array: + return self._map_index_base(expr, indices) + + def map_non_contiguous_advanced_index( + self, expr: IndexBase, indices: IndexesT + ) -> Array: + return self._map_index_base(expr, indices) + + def map_size_param(self, expr: SizeParam, indices: IndexesT) -> Array: + raise NotImplementedError + + def map_index_lambda(self, expr: IndexLambda, indices: IndexesT) -> Array: + if _is_materialized(expr): + new_expr = expr.replace_if_different( + bindings=constantdict( + { + name: self.rec_w_passthru_indices(bnd) + for name, bnd in expr.bindings.items() + } + ) + ) + return self._eagerly_index(new_expr, indices) + else: + hlo = index_lambda_to_high_level_op(expr) + + if isinstance(hlo, BinaryOp): + x1, x2 = hlo.x1, hlo.x2 + if isinstance(x1, Array): + x1 = self.rec_w_broadcast(x1, expr.shape, indices) + if isinstance(x2, Array): + x2 = self.rec_w_broadcast(x2, expr.shape, indices) + new_hlo = _lower_binary_op_hlo(replace(hlo, x1=x1, x2=x2)) + elif isinstance(hlo, BroadcastOp): + from pytato.array import broadcast_to + + x = hlo.x + if isinstance(x, Array): + x = self.rec_w_broadcast(x, expr.shape, indices) + + new_hlo = broadcast_to( + x, + _get_indices_shape(indices), + ) + elif isinstance(hlo, C99CallOp): + new_args = tuple( + ( + self.rec_w_broadcast(ary_arg, expr.shape, indices) + if isinstance(ary_arg, Array) + else ary_arg + ) + for ary_arg in hlo.args + ) + new_hlo = _lower_call_op_hlo(replace(hlo, args=new_args)) + elif isinstance(hlo, LogicalNotOp): + from pytato.array import logical_not + + new_hlo = logical_not(self.rec_ary(hlo.x, indices)) + elif isinstance(hlo, ReduceOp): + # TODO: Skipping for now. (Should be doable after figure out the + # appropriate transpose of the result.) + raise NotImplementedError + elif isinstance(hlo, WhereOp): + from pytato.array import where + + (cond, then, else_) = [ + ( + self.rec_w_broadcast(ary_arg, expr.shape, indices) + if isinstance(ary_arg, Array) + else ary_arg + ) + for ary_arg in [hlo.condition, hlo.then, hlo.else_] + ] + new_hlo = where(cond, then, else_) + elif isinstance(hlo, ZerosLikeOp): + from pytato.cmath import zeros_like + + new_hlo = zeros_like(self.rec_ary(hlo.x, indices), expr.dtype) + else: + raise NotImplementedError(type(hlo)) + assert isinstance(new_hlo, IndexLambda) + return expr.replace_if_different( + bindings=new_hlo.bindings, + expr=new_hlo.expr, + shape=new_hlo.shape, + axes=new_hlo.axes, + ) + + def map_concatenate(self, expr: Concatenate, indices: IndexesT) -> Array: + if _is_materialized(expr): + return self._eagerly_index( + expr.replace_if_different( + arrays=tuple( + self.rec_w_passthru_indices(ary) for ary in expr.arrays + ) + ), + indices, + ) + else: + # TODO: Skipping for now. (Should be doable with some swizzling + # of expr.arrays, don't see any need for it now.) + raise NotImplementedError + + def map_stack(self, expr: Stack, indices: IndexesT) -> Array: + if _is_materialized(expr): + return self._eagerly_index( + expr.replace_if_different( + arrays=tuple( + self.rec_w_passthru_indices(ary) for ary in expr.arrays + ) + ), + indices, + ) + else: + # TODO: Skipping for now. (Should be doable with some swizzling + # of expr.arrays, don't see any need for it now.) + if all( + _is_trivial_slice(axis_len, idx) + for axis_len, idx in zip(expr.shape, indices, strict=True) + ): + # handling a special case needed in grudge expressions + return expr.replace_if_different( + arrays=tuple( + self.rec_w_passthru_indices(ary) for ary in expr.arrays + ) + ) + raise NotImplementedError + + def map_roll(self, expr: Roll, indices: IndexesT) -> Array: + if _is_materialized(expr): + return self._eagerly_index( + expr.replace_if_different( + array=self.rec_w_passthru_indices(expr.array) + ), + indices, + ) + else: + # TODO: Skipping for now. (Should be doable with a modulo operation, + # don't see any need for it now.) + raise NotImplementedError + + def map_axis_permutation(self, expr: AxisPermutation, indices: IndexesT) -> Array: + if _is_materialized(expr): + return self._eagerly_index( + expr.replace_if_different( + array=self.rec_w_passthru_indices(expr.array) + ), + indices, + ) + else: + # TODO: Skipping for now. The permutation axes have to be changed to + # play with non-contiguous advanced indexing. Not needed for now. + raise NotImplementedError + + def map_reshape(self, expr: Reshape, indices: IndexesT) -> Array: + if _is_materialized(expr): + return self._eagerly_index( + expr.replace_if_different( + array=self.rec_w_passthru_indices(expr.array) + ), + indices, + ) + else: + # TODO: Skipping for now. For certain cases, like the expand_dims + # reshape, propagating the indices should be doable. Not needed for now. + if expr.order == "C": + # handle a special case needed in grudge expressions + matching_trailing_axes = 0 + for new_axis_len, old_axis_len in zip( + expr.shape[::-1], expr.array.shape[::-1], strict=False + ): + if are_shape_components_equal(old_axis_len, new_axis_len): + matching_trailing_axes += 1 + else: + break + leading_indices, trailing_indices = ( + indices[: expr.ndim - matching_trailing_axes], + indices[expr.ndim - matching_trailing_axes :], + ) + if all( + _is_trivial_slice(axis_len, idx) + for axis_len, idx in zip(expr.shape, leading_indices, strict=False) + ): + array_indices = ( + tuple( + NormalizedSlice(0, axis_len, 1) + for axis_len in expr.array.shape[ + : expr.array.ndim - matching_trailing_axes + ] + ) + + trailing_indices + ) + assert len(array_indices) == expr.array.ndim + newshape = _get_indices_shape(indices) + return expr.replace_if_different( + array=self.rec_ary(expr.array, array_indices), + newshape=( + expr.newshape if newshape == expr.newshape else newshape + ), + axes=( + expr.axes + if newshape == expr.newshape + else tuple(Axis(frozenset()) for _ in newshape) + ), + ) + + raise NotImplementedError + + def map_einsum(self, expr: Einsum, indices: IndexesT) -> Array: + if _is_materialized(expr): + return self._eagerly_index( + expr.replace_if_different( + args=tuple(self.rec_w_passthru_indices(arg) for arg in expr.args) + ), + indices, + ) + else: + # TODO: Skipping for now. Involves transmitting the indices to the + # operands as well as changing the subscript itself. + if all( + _is_trivial_slice(axis_len, idx) + for axis_len, idx in zip(expr.shape, indices, strict=True) + ): + # handling a special case needed in grudge expressions + return expr.replace_if_different( + args=tuple(self.rec_w_passthru_indices(arg) for arg in expr.args) + ) + raise NotImplementedError + + def map_loopy_call(self, expr: LoopyCall, indices: IndexesT) -> LoopyCall: + if indices != (): + raise ValueError("map_loopy_call must be called with outer indices = ().") + return expr.replace_if_different( + bindings=constantdict( + { + name: ( + self.rec_w_passthru_indices(bnd) + if isinstance(bnd, Array) + else bnd + ) + for name, bnd in expr.bindings.items() + } + ) + ) + + def map_loopy_call_result(self, expr: LoopyCallResult, indices: IndexesT) -> Array: + new_loopy_call = self.rec(expr._container, ()) + assert isinstance(new_loopy_call, LoopyCall) + return self._eagerly_index(new_loopy_call[expr.name], indices) + + def map_named_call_result(self, expr: NamedCallResult, indices: IndexesT) -> Array: + # TODO: Maybe we should propagate indices to the function definition itself? + raise NotImplementedError( + "NamedCall results currently not supported in" + " push_index_to_materialized_nodes." + ) + + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays, indices: IndexesT + ) -> DictOfNamedArrays: + if indices != (): + raise ValueError( + "map_dict_of_named_arrays must be called with outer indices = ()." + ) + return expr.replace_if_different( + data=constantdict( + { + name: self.rec_w_passthru_indices(subexpr) + for name, subexpr in expr._data.items() + } + ) + ) + + +def push_index_to_materialized_nodes(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + """ + Returns a transformed version of *expr* such that all indexing nodes in + *expr* are propagated towards materialized arrays. Consequently, the + transformed *expr* will contain expressions of the form + ``x[idx1, idx2, ..., idxn]`` only if ``x`` is a materialized array. + + We consider an array as materialized when it is either an instance of + :class:`pytato.InputArgumentBase` or has been tagged with + :class:`pytato.tags.ImplStored`. + """ + from pytato.transform import deduplicate + + mapper = IndexPusher() + if isinstance(expr, Array): + return deduplicate( + mapper.rec_ary( + expr, tuple(NormalizedSlice(0, axis_len, 1) for axis_len in expr.shape) + ) + ) + else: + assert isinstance(expr, DictOfNamedArrays) + result_w_dups = mapper(expr, ()) + assert isinstance(result_w_dups, DictOfNamedArrays) + return deduplicate(result_w_dups) From 708e8ba4acc0095cfc5e7c7e9480d68cccf22567 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 9 Feb 2026 10:19:53 -0800 Subject: [PATCH 3/3] Test pt.push_index_to_materialized_nodes. --- test/test_transform.py | 1173 ++++++++++++++++++++++++++++++++++++++++ test/testlib.py | 43 +- 2 files changed, 1214 insertions(+), 2 deletions(-) create mode 100644 test/test_transform.py diff --git a/test/test_transform.py b/test/test_transform.py new file mode 100644 index 000000000..80fb21702 --- /dev/null +++ b/test/test_transform.py @@ -0,0 +1,1173 @@ +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2026 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import numpy as np +import pytest +from numpy.random import default_rng +from testlib import assert_allclose_to_ref +from typing_extensions import override + +import pyopencl as cl +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests, # pyright: ignore[reportUnusedImport] +) + +import pytato as pt + + +class AssertOnlyMaterializedIndexees(pt.transform.CachedWalkMapper[[]]): + @override + def get_cache_key(self, expr: pt.transform.ArrayOrNames) -> pt.transform.CacheKeyT: + return (expr,) + + @override + def _map_index_base(self, expr: pt.IndexBase) -> None: + indexee = expr.array + assert ( + isinstance(indexee, pt.InputArgumentBase) + or pt.tags.ImplStored() in expr.tags + ) + self.rec(indexee) # do not recurse over indexes. + + +def assert_only_materialized_indexees(expr: pt.transform.ArrayOrNames) -> None: + mapper = AssertOnlyMaterializedIndexees() + mapper(expr) + + +def test_indirection_pusher_0(): + x = pt.make_placeholder("x", 10) + idx = pt.make_placeholder("idx", 1729, np.int32) + y = x[idx] + assert pt.push_index_to_materialized_nodes(y) == y + + +def test_indirection_pusher_1(ctx_factory): + x = pt.make_placeholder("x", 10) + idx = pt.make_placeholder("idx", 1729, np.int32) + y = (2 * x)[idx] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == 2 * (x[idx]) + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + {"x": rng.random(10), "idx": rng.integers(0, 10, 1729, np.int32)}, + ) + + +def test_indirection_pusher_2(ctx_factory): + x1 = pt.make_placeholder("x1", 10) + x2 = pt.make_placeholder("x2", 10) + idx = pt.make_placeholder("idx", 1729, np.int32) + y = (x1 * x2)[idx] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == (x1[idx] * x2[idx]) + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x1": rng.random(10), + "x2": rng.random(10), + "idx": rng.integers(0, 10, 1729, np.int32), + }, + ) + + +def test_indirection_pusher_3(ctx_factory): + x = pt.make_placeholder("x", 10) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 314, np.int32) + y = x[idx1][idx2] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx1[idx2]] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random(10), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 1729, 314, np.int32), + }, + ) + + +def test_indirection_pusher_4(ctx_factory): + x = pt.make_placeholder("x", 10) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 314, np.int32) + y = (2 * x[idx1])[idx2] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == 2 * x[idx1[idx2]] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random(10), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 1729, 314, np.int32), + }, + ) + + +def test_indirection_pusher_5(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + idx5 = pt.make_placeholder("idx5", 314, np.int32) + + y = x[:, idx1, idx2, :][idx3, idx4, idx5] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx3, idx1[idx4], idx2[idx4], idx5] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 10, 314, np.int32), + "idx4": rng.integers(0, 1729, 314, np.int32), + "idx5": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_6(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + idx5 = pt.make_placeholder("idx5", 314, np.int32) + + y = x[::2, idx1, idx2, ::3][idx3, idx4, idx5] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[2 * idx3, idx1[idx4], idx2[idx4], 3 * idx5] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 5, 314, np.int32), + "idx4": rng.integers(0, 1729, 314, np.int32), + "idx5": rng.integers(0, 4, 314, np.int32), + }, + ) + + +def test_indirection_pusher_7(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + + y = x[idx1, :, idx2][idx3, idx4] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx1[idx3], idx4, idx2[idx3]] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 1729, 314, np.int32), + "idx4": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_8(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + + y = x[idx1, ::2, idx2][idx3, idx4] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx1[idx3], 2 * idx4, idx2[idx3]] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 1729, 314, np.int32), + "idx4": rng.integers(0, 5, 314, np.int32), + }, + ) + + +def test_indirection_pusher_9(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + + y = x[idx1, idx2, ::2, ::3][idx3, :, idx4] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx1[idx3], idx2[idx3], ::2, 3 * idx4] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 1729, 314, np.int32), + "idx4": rng.integers(0, 4, 314, np.int32), + }, + ) + + +def test_indirection_pusher_10(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + # (_0, _1, _2) -> (idx1[_0], 2*_1, idx2[_0], _2) + # (_0, _1) -> (idx3[_0], 3*_1, idx4[_0]) + # Net: + # (_0, _1) -> (idx1[idx3[_0]], 6*_1, idx2[idx3[_0]], idx4[_0]) + + y = x[idx1, ::2, idx2][idx3, ::3, idx4] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx1[idx3], ::6, idx2[idx3], idx4] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 1729, 314, np.int32), + "idx4": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_11(ctx_factory): + x1 = pt.make_placeholder("x1", (10, 1, 10, 1)) + x2 = pt.make_placeholder("x2", (1, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + y1 = (x1 + x2)[:, idx1, idx2, :] + # (_0, _1, _2, _3) -> x1[_0, 0, _2, 0] + x2[0, _1, _2, _3] + # (_0, _1, _2) -> (_0, idx1[_1], idx2[_1], _2]) + # Net -> + # (_0, _1, _2) -> x1[_0, 0, idx2[_1], 0] + x2[0, idx1[_1], idx2[_1], _2] + y2 = x1[:, 0, idx2, :] + x2[:, idx1, idx2, :] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x1": rng.random((10, 1, 10, 1)), + "x2": rng.random((1, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + }, + ) + + +def test_indirection_pusher_12(ctx_factory): + x1 = pt.make_placeholder("x1", (10, 1, 10, 1)) + x2 = pt.make_placeholder("x2", (1, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + y1 = (x1 + x2)[idx1, :, idx2, :] + # (_0, _1, _2, _3) -> x1[_0, 0, _2, 0] + x2[0, _1, _2, _3] + # (_0, _1, _2) -> (idx1[_0], _1, idx2[_0], _2) + # Net-> + # (_0, _1, _2) -> x1[idx1[_0], 0, idx2[_0], 0] + x2[0, _1, idx2[_0], _2] + + y2 = x1[idx1, :, idx2, :] + x2[0, :, idx2, :] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x1": rng.random((10, 1, 10, 1)), + "x2": rng.random((1, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + }, + ) + + +@pytest.mark.xfail(reason="axis permutation not yet supported.") +def test_indirection_pusher_13(): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + y1 = pt.transpose(x, (0, 2, 3, 1))[idx1, :idx2, :] + # (_0, _1, _2, _3) -> (_0, _2, _3, _1) + # (_0, _1, _2) -> (idx1[_0], _1, idx2[_0], _2) + # Net-> + # (idx1[_0], idx2[_0], _2, _1) + y2 = pt.transpose(x[idx1, idx2], (0, 1, 3, 2)) + assert pt.push_index_to_materialized_nodes(y1) == y2 + + +@pytest.mark.xfail(reason="axis permutation not yet supported.") +def test_indirection_pusher_14(): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + y1 = pt.transpose(x, (0, 2, 3, 1))[idx1, :idx2, :] + # (_0, _1, _2, _3) -> (_0, _2, _3, _1) + # (_0, _1, _2) -> (idx1[_0], _1, idx2[_0], _2) + # Net-> + # (idx1[_0], idx2[_0], _2, _1) + y2 = pt.transpose(x[idx1, idx2], (0, 1, 3, 2)) + assert pt.push_index_to_materialized_nodes(y1) == y2 + + +def test_indirection_pusher_15(ctx_factory): + x = pt.make_placeholder("x", (10, 10)) + idx1 = pt.make_placeholder("idx1", 4, np.int32) + idx2 = pt.make_placeholder("idx2", (10, 4), np.int32) + idx3 = pt.make_placeholder("idx3", (1, 10, 10), np.int32) + idx4 = pt.make_placeholder("idx4", (10, 10, 10), np.int32) + + y = x[idx1, idx2][idx3, idx4] + y_prime = pt.push_index_to_materialized_nodes(y) + assert y_prime == x[idx1[idx4], idx2[idx3, idx4]] + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref( + y_prime, + y, + cq, + { + "x": rng.random((10, 10)), + "idx1": rng.integers(0, 10, 4, np.int32), + "idx2": rng.integers(0, 10, (10, 4), np.int32), + "idx3": rng.integers(0, 10, (1, 10, 10), np.int32), + "idx4": rng.integers(0, 4, (10, 10, 10), np.int32), + }, + ) + + +def test_indirection_pusher_16(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10)) + idx1 = pt.make_placeholder("idx1", (4, 1, 4), np.int32) + idx2 = pt.make_placeholder("idx2", (10, 4), np.int32) + idx3 = pt.make_placeholder("idx3", (10, 4), np.int32) + idx4 = pt.make_placeholder("idx4", (10, 1), np.int32) + idx5 = pt.make_placeholder("idx5", (10, 10), np.int32) + y1 = x[idx1, idx2, idx3][idx4, 2:5, idx5] + y2 = x[ + idx1[idx4, :, idx5], + pt.transpose(idx2[2:5, idx5], (1, 2, 0)), + pt.transpose(idx3[2:5, idx5], (1, 2, 0)), + ] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10)), + "idx1": rng.integers(0, 10, (4, 1, 4), np.int32), + "idx2": rng.integers(0, 10, (10, 4), np.int32), + "idx3": rng.integers(0, 10, (10, 4), np.int32), + "idx4": rng.integers(0, 4, (10, 1), np.int32), + "idx5": rng.integers(0, 4, (10, 10), np.int32), + }, + ) + + +def test_indirection_pusher_17(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + y1 = x[:, idx1, :, idx2][:, idx3, idx4] + y2 = x[idx3, pt.expand_dims(idx1, 1), idx4, pt.expand_dims(idx2, 1)] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 10, 314, np.int32), + "idx4": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_18(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + y1 = x[:, idx1, idx2, :][:, idx3, idx4] + y2 = x[:, idx1[idx3], idx2[idx3], idx4] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 1729, 314, np.int32), + "idx4": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_19(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + idx4 = pt.make_placeholder("idx4", 314, np.int32) + y1 = x[:, idx1, :, idx2, :][:, :, idx3, idx4] + y2 = pt.transpose( + x[:, pt.expand_dims(idx1, 1), idx3, pt.expand_dims(idx2, 1), idx4], (1, 0, 2) + ) + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 10, 314, np.int32), + "idx4": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_20(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", (271, 172, 31), np.int32) + idx2 = pt.make_placeholder("idx2", (172, 31), np.int32) + idx3 = pt.make_placeholder("idx3", 6, np.int32) + idx4 = pt.make_placeholder("idx4", (10, 6), np.int32) + y1 = x[:, idx1, :, idx2][:, :, :, idx3, idx4] + y2 = x[idx3, pt.expand_dims(idx1, (3, 4)), idx4, pt.expand_dims(idx2, (2, 3))] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, (271, 172, 31), np.int32), + "idx2": rng.integers(0, 10, (172, 31), np.int32), + "idx3": rng.integers(0, 10, 6, np.int32), + "idx4": rng.integers(0, 10, (10, 6), np.int32), + }, + ) + + +def test_indirection_pusher_21(ctx_factory): + x1 = pt.make_placeholder("x1", (10, 1, 10)) + x2 = pt.make_placeholder("x2", (10, 10, 10)) + idx1 = pt.make_placeholder("idx1", (6, 1, 1), np.int32) + idx2 = pt.make_placeholder("idx2", (6, 1), np.int32) + idx3 = pt.make_placeholder("idx3", (6), np.int32) + y1 = (x1 + x2)[idx1, idx2, idx3] + y2 = x1[idx1, 0, idx3] + x2[idx1, idx2, idx3] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x1": rng.random((10, 1, 10)), + "x2": rng.random((10, 10, 10)), + "idx1": rng.integers(0, 10, (6, 1, 1), np.int32), + "idx2": rng.integers(0, 10, (6, 1), np.int32), + "idx3": rng.integers(0, 10, 6, np.int32), + }, + ) + + +def test_indirection_pusher_22(ctx_factory): + x1 = pt.make_placeholder("x1", (10, 10)) + x2 = pt.make_placeholder("x2", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 1729, np.int32) + y1 = (x1 + x2)[idx1, :, idx2, idx3] + y2 = pt.expand_dims(x1[idx2, idx3], 1) + x2[idx1, :, idx2, idx3] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x1": rng.random((10, 10)), + "x2": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 10, 1729, np.int32), + }, + ) + + +def test_indirection_pusher_23(ctx_factory): + x1 = pt.make_placeholder("x1", (10, 10, 10)) + x2 = pt.make_placeholder("x2", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 1729, np.int32) + y1 = (x1 + x2)[idx1, :, idx2, idx3] + y2 = pt.transpose(x1[:, idx2, idx3], (1, 0)) + x2[idx1, :, idx2, idx3] + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x1": rng.random((10, 10, 10)), + "x2": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 10, 1729, np.int32), + }, + ) + + +def test_indirection_pusher_24(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 314, np.int32) + idx3 = pt.make_placeholder("idx3", 314, np.int32) + y1 = x[:, :, idx1, :][:, idx2, :, idx3] + y2 = pt.transpose( + x[:, pt.expand_dims(idx2, 1), idx1, pt.expand_dims(idx3, 1)], (1, 0, 2) + ) + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 314, np.int32), + "idx3": rng.integers(0, 10, 314, np.int32), + }, + ) + + +def test_indirection_pusher_25(ctx_factory): + x1 = pt.make_placeholder("x1", (10, 10, 10)) + x2 = pt.make_placeholder("x2", (10, 10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", 1729, np.int32) + idx2 = pt.make_placeholder("idx2", 1729, np.int32) + idx3 = pt.make_placeholder("idx3", 1729, np.int32) + y1 = (x1 + x2)[:, idx1, :, idx2, idx3] + y2 = ( + pt.transpose(pt.expand_dims(x1[:, idx2, idx3], 2), (1, 2, 0)) + + x2[:, idx1, :, idx2, idx3] + ) + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x1": rng.random((10, 10, 10)), + "x2": rng.random((10, 10, 10, 10, 10)), + "idx1": rng.integers(0, 10, 1729, np.int32), + "idx2": rng.integers(0, 10, 1729, np.int32), + "idx3": rng.integers(0, 10, 1729, np.int32), + }, + ) + + +def test_indirection_pusher_26(ctx_factory): + x = pt.make_placeholder("x", (10, 10, 10, 10)) + idx1 = pt.make_placeholder("idx1", (4, 1, 4), np.int32) + idx2 = pt.make_placeholder("idx2", (10, 4), np.int32) + idx3 = pt.make_placeholder("idx3", (10, 4), np.int32) + idx4 = pt.make_placeholder("idx4", (10, 1), np.int32) + idx5 = pt.make_placeholder("idx5", (10, 10), np.int32) + y1 = x[:, idx1, idx2, idx3][:, idx4, 2:5, idx5] + # In the computation of y1. + # tmp1[_0, _1, _2, _3] = x[_0, idx1[_1, 0, _3], idx2[_2, _3], idx3[_2, _3]] + # y1[_0, _1, _2, _3] = tmp1[_2, idx4[_0, 0], _3+2, idx5[_0, _1]] + # Net + # y1[_0, _1, _2, _3] = + # x[_2, + # idx1[idx4[_0, 0], 0, idx5[_0, _1]], + # idx2[_3+2, idx5[_0, _1]], + # idx3[_3 +2, idx5[_0, _1]]] + y2 = pt.transpose( + x[ + :, + idx1[idx4, :, idx5], + pt.transpose(idx2[2:5, idx5], (1, 2, 0)), + pt.transpose(idx3[2:5, idx5], (1, 2, 0)), + ], + (1, 2, 0, 3), + ) + assert pt.push_index_to_materialized_nodes(y1) == y2 + + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + assert_only_materialized_indexees(y2) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y1) + + assert_allclose_to_ref( + y2, + y1, + cq, + { + "x": rng.random((10, 10, 10, 10)), + "idx1": rng.integers(0, 10, (4, 1, 4), np.int32), + "idx2": rng.integers(0, 10, (10, 4), np.int32), + "idx3": rng.integers(0, 10, (10, 4), np.int32), + "idx4": rng.integers(0, 4, (10, 1), np.int32), + "idx5": rng.integers(0, 4, (10, 10), np.int32), + }, + ) + + +def dgfem_flux( + u, + map_, + map_0, + map_1, + map_2, + map_3, + map_4, + map_5, + map_6, + map_7, + map_8, + map_9, + map_10, + map_11, + map_12, + map_13, +): + tmp_3 = u[map_.reshape((192, 1)), map_0[map_1]] + tmp_2 = tmp_3 - tmp_3 + tmp_1 = tmp_2[map_2.reshape((1536, 1)), map_3[map_4]] + tmp_11 = u[map_5.reshape((1344, 1)), map_6[map_7]] + tmp_9 = tmp_11[map_8.reshape((1344, 1)), map_9[map_10]] - tmp_11 + tmp_8 = tmp_9[map_11.reshape((1536, 1)), map_12[map_13]] + tmp_0 = tmp_1 + tmp_8 + return tmp_0.reshape((4, 384, 15)) + + +def test_indirection_pusher_27(ctx_factory): + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + dgfem_flux_inputs = { + "u": rng.random((384, 35), np.float64), + "map_": rng.integers(0, 384, 192, np.int32), + "map_0": rng.integers(0, 35, (4, 15), np.int64), + "map_1": rng.integers(0, 4, 192, np.int8), + "map_2": rng.integers(0, 192, 1536, np.int32), + "map_3": rng.integers(0, 15, (1, 15), np.int64), + "map_4": rng.integers(0, 1, 1536, np.int8), + "map_5": rng.integers(0, 1536, 1344, np.int32), + "map_6": rng.integers(0, 35, (4, 15), np.int64), + "map_7": rng.integers(0, 4, 1344, np.int8), + "map_8": rng.integers(0, 1344, 1344, np.int32), + "map_9": rng.integers(0, 15, (3, 15), np.int64), + "map_10": rng.integers(0, 3, 1344, np.int8), + "map_11": rng.integers(0, 1344, 1536, np.int32), + "map_12": rng.integers(0, 15, (1, 15), np.int64), + "map_13": rng.integers(0, 1, 1536, np.int8), + } + + y = dgfem_flux( + **{ + k: pt.make_placeholder(k, v.shape, v.dtype) + for k, v in dgfem_flux_inputs.items() + } + ) + + y_prime = pt.push_index_to_materialized_nodes(y) + + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref(y, y_prime, cq, dgfem_flux_inputs) + + +def wave_3d_p4_all_fluxes( + *, + u, + v_0, + v_1, + v_2, + map_, + map_0, + map_1, + map_2, + map_3, + map_4, + map_5, + map_6, + map_7, + map_8, + map_9, + map_10, + map_11, + map_12, + map_13, + map_14, + map_15, + map_16, + map_17, + map_18, + map_19, + map_20, + map_21, +): + tmp_4 = pt.reshape(map_, (1536, 1)) + tmp_13 = pt.reshape(map_0, (192, 1)) + tmp_14 = map_1[map_2,] + tmp_12 = u[tmp_13, tmp_14] + tmp_11 = 0.0 + tmp_12 + tmp_15 = -1.0 * tmp_11 + tmp_10 = tmp_11 + tmp_15 + tmp_9 = 0.5 * tmp_10 + tmp_8 = tmp_9 * map_3 + tmp_23 = v_0[tmp_13, tmp_14] + tmp_22 = 0.0 + tmp_23 + tmp_21 = tmp_22 - tmp_22 + tmp_20 = tmp_21 * map_4 + tmp_27 = v_1[tmp_13, tmp_14] + tmp_26 = 0.0 + tmp_27 + tmp_25 = tmp_26 - tmp_26 + tmp_24 = tmp_25 * map_3 + tmp_19 = tmp_20 + tmp_24 + tmp_31 = v_2[tmp_13, tmp_14] + tmp_30 = 0.0 + tmp_31 + tmp_29 = tmp_30 - tmp_30 + tmp_28 = tmp_29 * map_5 + tmp_18 = tmp_19 + tmp_28 + tmp_17 = 0.5 * tmp_18 + tmp_16 = tmp_17 * map_3 + tmp_7 = tmp_8 + tmp_16 + tmp_6 = 1.0 * tmp_7 + tmp_32 = pt.reshape(map_6, (1536, 1)) + tmp_33 = map_7[map_8,] + tmp_5 = tmp_6[tmp_32, tmp_33] + tmp_3 = pt.where(tmp_4, tmp_5, 0) + tmp_2 = 0.0 + tmp_3 + tmp_37 = pt.reshape(map_9, (1536, 1)) + tmp_46 = pt.reshape(map_10, (1344, 1)) + tmp_47 = map_11[map_12,] + tmp_45 = u[tmp_46, tmp_47] + tmp_44 = 0.0 + tmp_45 + tmp_50 = pt.reshape(map_13, (1344, 1)) + tmp_51 = map_14[map_15,] + tmp_49 = tmp_44[tmp_50, tmp_51] + tmp_48 = 0.0 + tmp_49 + tmp_43 = tmp_44 + tmp_48 + tmp_42 = 0.5 * tmp_43 + tmp_41 = tmp_42 * map_16 + tmp_61 = v_0[tmp_46, tmp_47] + tmp_60 = 0.0 + tmp_61 + tmp_59 = tmp_60[tmp_50, tmp_51] + tmp_58 = 0.0 + tmp_59 + tmp_57 = tmp_58 - tmp_60 + tmp_56 = tmp_57 * map_17 + tmp_67 = v_1[tmp_46, tmp_47] + tmp_66 = 0.0 + tmp_67 + tmp_65 = tmp_66[tmp_50, tmp_51] + tmp_64 = 0.0 + tmp_65 + tmp_63 = tmp_64 - tmp_66 + tmp_62 = tmp_63 * map_16 + tmp_55 = tmp_56 + tmp_62 + tmp_73 = v_2[tmp_46, tmp_47] + tmp_72 = 0.0 + tmp_73 + tmp_71 = tmp_72[tmp_50, tmp_51] + tmp_70 = 0.0 + tmp_71 + tmp_69 = tmp_70 - tmp_72 + tmp_68 = tmp_69 * map_18 + tmp_54 = tmp_55 + tmp_68 + tmp_53 = 0.5 * tmp_54 + tmp_52 = tmp_53 * map_16 + tmp_40 = tmp_41 + tmp_52 + tmp_39 = 1.0 * tmp_40 + tmp_74 = pt.reshape(map_19, (1536, 1)) + tmp_75 = map_20[map_21,] + tmp_38 = tmp_39[tmp_74, tmp_75] + tmp_36 = pt.where(tmp_37, tmp_38, 0) + tmp_35 = 0.0 + tmp_36 + tmp_34 = 0.0 + tmp_35 + tmp_1 = tmp_2 + tmp_34 + tmp_0 = pt.reshape(tmp_1, (4, 384, 15)) + tmp_87 = tmp_22 + tmp_22 + tmp_86 = 0.5 * tmp_87 + tmp_85 = tmp_86 * map_4 + tmp_90 = tmp_26 + tmp_26 + tmp_89 = 0.5 * tmp_90 + tmp_88 = tmp_89 * map_3 + tmp_84 = tmp_85 + tmp_88 + tmp_93 = tmp_30 + tmp_30 + tmp_92 = 0.5 * tmp_93 + tmp_91 = tmp_92 * map_5 + tmp_83 = tmp_84 + tmp_91 + tmp_95 = tmp_15 - tmp_11 + tmp_94 = 0.5 * tmp_95 + tmp_82 = tmp_83 + tmp_94 + tmp_81 = 1.0 * tmp_82 + tmp_80 = tmp_81[tmp_32, tmp_33] + tmp_79 = pt.where(tmp_4, tmp_80, 0) + tmp_78 = 0.0 + tmp_79 + tmp_106 = tmp_60 + tmp_58 + tmp_105 = 0.5 * tmp_106 + tmp_104 = tmp_105 * map_17 + tmp_109 = tmp_66 + tmp_64 + tmp_108 = 0.5 * tmp_109 + tmp_107 = tmp_108 * map_16 + tmp_103 = tmp_104 + tmp_107 + tmp_112 = tmp_72 + tmp_70 + tmp_111 = 0.5 * tmp_112 + tmp_110 = tmp_111 * map_18 + tmp_102 = tmp_103 + tmp_110 + tmp_114 = tmp_48 - tmp_44 + tmp_113 = 0.5 * tmp_114 + tmp_101 = tmp_102 + tmp_113 + tmp_100 = 1.0 * tmp_101 + tmp_99 = tmp_100[tmp_74, tmp_75] + tmp_98 = pt.where(tmp_37, tmp_99, 0) + tmp_97 = 0.0 + tmp_98 + tmp_96 = 0.0 + tmp_97 + tmp_77 = tmp_78 + tmp_96 + tmp_76 = pt.reshape(tmp_77, (4, 384, 15)) + tmp_122 = tmp_9 * map_5 + tmp_123 = tmp_17 * map_5 + tmp_121 = tmp_122 + tmp_123 + tmp_120 = 1.0 * tmp_121 + tmp_119 = tmp_120[tmp_32, tmp_33] + tmp_118 = pt.where(tmp_4, tmp_119, 0) + tmp_117 = 0.0 + tmp_118 + tmp_130 = tmp_42 * map_18 + tmp_131 = tmp_53 * map_18 + tmp_129 = tmp_130 + tmp_131 + tmp_128 = 1.0 * tmp_129 + tmp_127 = tmp_128[tmp_74, tmp_75] + tmp_126 = pt.where(tmp_37, tmp_127, 0) + tmp_125 = 0.0 + tmp_126 + tmp_124 = 0.0 + tmp_125 + tmp_116 = tmp_117 + tmp_124 + tmp_115 = pt.reshape(tmp_116, (4, 384, 15)) + tmp_139 = tmp_9 * map_4 + tmp_140 = tmp_17 * map_4 + tmp_138 = tmp_139 + tmp_140 + tmp_137 = 1.0 * tmp_138 + tmp_136 = tmp_137[tmp_32, tmp_33] + tmp_135 = pt.where(tmp_4, tmp_136, 0) + tmp_134 = 0.0 + tmp_135 + tmp_147 = tmp_42 * map_17 + tmp_148 = tmp_53 * map_17 + tmp_146 = tmp_147 + tmp_148 + tmp_145 = 1.0 * tmp_146 + tmp_144 = tmp_145[tmp_74, tmp_75] + tmp_143 = pt.where(tmp_37, tmp_144, 0) + tmp_142 = 0.0 + tmp_143 + tmp_141 = 0.0 + tmp_142 + tmp_133 = tmp_134 + tmp_141 + tmp_132 = pt.reshape(tmp_133, (4, 384, 15)) + return {"flux_0": tmp_0, "flux_1": tmp_76, "flux_2": tmp_115, "flux_3": tmp_132} + + +def test_indirection_pusher_28(ctx_factory): + rng = default_rng(42) + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + np_fields = { + "u": rng.random((384, 35)), + "v_0": rng.random((384, 35)), + "v_1": rng.random((384, 35)), + "v_2": rng.random((384, 35)), + } + np_indirections = { + "map_": rng.integers(0, 2, 1536, np.int8), + "map_0": rng.integers(0, 384, 192, np.int32), + "map_1": rng.integers(0, 35, (4, 15), np.int64), + "map_2": rng.integers(0, 4, 192, np.int8), + "map_3": rng.random((192, 1)), + "map_4": rng.random((192, 1)), + "map_5": rng.random((192, 1)), + "map_6": rng.integers(0, 192, 1536, np.int32), + "map_7": rng.integers(0, 15, (1, 15), np.int64), + "map_8": rng.integers(0, 1, 1536, np.int8), + "map_9": rng.integers(0, 2, 1536, np.int8), + "map_10": rng.integers(0, 384, 1344, np.int32), + "map_11": rng.integers(0, 35, (4, 15), np.int64), + "map_12": rng.integers(0, 4, 1344, np.int8), + "map_13": rng.integers(0, 1344, 1344, np.int32), + "map_14": rng.integers(0, 15, (3, 15), np.int64), + "map_15": rng.integers(0, 3, 1344, np.int8), + "map_16": rng.random((1344, 1)), + "map_17": rng.random((1344, 1)), + "map_18": rng.random((1344, 1)), + "map_19": rng.integers(0, 1344, 1536, np.int32), + "map_20": rng.integers(0, 15, (1, 15), np.int64), + "map_21": rng.integers(0, 1, 1536, np.int8), + } + pt_fields = { + name: pt.make_placeholder(name, np_field.shape, np_field.dtype) + for name, np_field in np_fields.items() + } + pt_indirections = { + name: pt.make_data_wrapper(np_indirection) + for name, np_indirection in np_indirections.items() + } + + y = pt.make_dict_of_named_arrays( + wave_3d_p4_all_fluxes(**pt_fields, **pt_indirections) + ) + + y_prime = pt.push_index_to_materialized_nodes(y) + + assert_only_materialized_indexees(y_prime) + with pytest.raises(AssertionError): + assert_only_materialized_indexees(y) + + assert_allclose_to_ref(y, y_prime, cq, np_fields) diff --git a/test/testlib.py b/test/testlib.py index 0b587f400..a0099e04f 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -10,7 +10,8 @@ from pytools.tag import Tag import pytato as pt -from pytato.transform import Mapper +from pytato.array import Array +from pytato.transform import ArrayOrNamesTc, Mapper if TYPE_CHECKING: @@ -20,7 +21,6 @@ import pyopencl as cl from pytato.array import ( - Array, AxisPermutation, Concatenate, DataWrapper, @@ -91,6 +91,45 @@ def assert_allclose_to_numpy(expr: Array, queue: cl.CommandQueue, np.testing.assert_allclose(np_result, pt_result, rtol=rtol) + +def assert_allclose_to_ref( + expr: ArrayOrNamesTc, + ref: ArrayOrNamesTc, + queue: cl.CommandQueue, + parameters: dict[str, Any] | None = None, + rtol: float = 1e-7, +) -> None: + """ + Raises an :class:`AssertionError`, if there is a value discrepancy between + *expr* and *ref* on evaluation with the placeholders values as *parameters*. + + :arg queue: An instance of :class:`pyopencl.CommandQueue` to which the + generated kernel must be enqueued. + """ + if parameters is None: + parameters = {} + if isinstance(expr, Array): + assert isinstance(ref, Array) + expr_dict = pt.make_dict_of_named_arrays({"_pt_out": expr}) + ref_dict = pt.make_dict_of_named_arrays({"_pt_out": ref}) + else: + expr_dict = expr + ref_dict = ref + + ref_prog = pt.generate_loopy(ref_dict) + prog = pt.generate_loopy(expr_dict) + + _evt, ref_evaled = ref_prog(queue, **parameters) + _evt, expr_evaled = prog(queue, **parameters) + + assert set(ref_evaled.keys()) == set(expr_evaled.keys()) + for name, subref in ref_evaled.items(): + subexpr = expr_evaled[name] + assert subref.shape == subexpr.shape + assert subref.dtype == subexpr.dtype + np.testing.assert_allclose(subexpr, subref, rtol=rtol) + + # }}}