diff --git a/pytato/__init__.py b/pytato/__init__.py index 88ce4e4e9..654b846e7 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -115,6 +115,11 @@ def set_debug_enabled(flag: bool) -> None: from pytato.transform.metadata import unify_axes_tags from pytato.function import trace_call from pytato.array import set_traceback_tag_enabled +from pytato.transform.indirections import ( + decouple_multi_axis_indirections_into_single_axis_indirections, + push_axis_indirections_towards_materialized_nodes, + fold_constant_indirections, +) __all__ = ( "dtype", @@ -183,6 +188,10 @@ def set_debug_enabled(flag: bool) -> None: "unify_axes_tags", + "decouple_multi_axis_indirections_into_single_axis_indirections", + "push_axis_indirections_towards_materialized_nodes", + "fold_constant_indirections", + # sub-modules "analysis", "tags", "transform", "function", diff --git a/pytato/transform/indirections.py b/pytato/transform/indirections.py new file mode 100644 index 000000000..232c03a1a --- /dev/null +++ b/pytato/transform/indirections.py @@ -0,0 +1,995 @@ +__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import sys
+# NOTE(review): ``typing.TypeAlias`` requires Python >= 3.10, but the
+# ``more_itertools`` fallback below implies older Pythons are supported —
+# consider importing it from ``typing_extensions``. TODO confirm minimum
+# supported Python version.
+from typing import (Any, Dict, Mapping, Tuple, TypeAlias, Iterable,
+                    FrozenSet, Union, Set, List, Optional, Callable)
+from pytato.array import (Array, InputArgumentBase, DictOfNamedArrays,
+                          IndexLambda, ShapeComponent,
+                          NormalizedSlice,
+                          AdvancedIndexInContiguousAxes,
+                          AdvancedIndexInNoncontiguousAxes,
+                          BasicIndex, Reshape, Roll, Einsum, AxisPermutation,
+                          Stack, Concatenate, DataWrapper, IndexBase, Placeholder)
+
+from pytato.tags import ImplStored
+from pytato.transform import (CachedMapper, CopyMapper, Mapper, MappedT,
+                              ArrayOrNames, CombineMapper)
+from dataclasses import dataclass
+import pytato.scalar_expr as scalar_expr
+import pymbolic.primitives as prim
+from immutables import Map
+from pytato.utils import are_shape_components_equal
+
+# BUGFIX: ``sys.version`` is a *string*; comparing it against a tuple raises
+# ``TypeError`` at import time. ``sys.version_info`` is the comparable tuple.
+if sys.version_info >= (3, 11):
+    def zip_equal(*args: Any) -> Any:
+        """Like :func:`zip`, but raises if the iterables differ in length."""
+        return zip(*args, strict=True)
+else:
+    from more_itertools import zip_equal
+
+_ComposedIndirectionT: TypeAlias = Tuple[Array, ...]
+IndexT: TypeAlias = Union[Array, NormalizedSlice]
+# BUGFIX: ``IndexStackT`` is referenced throughout this module (e.g. by
+# ``_take_along_axis`` and every ``_IndirectionPusher`` method) but was never
+# defined, making those annotations a ``NameError`` when evaluated. It is a
+# stack of indexers applied to a single axis, outermost first.
+IndexStackT: TypeAlias = Tuple[IndexT, ...]
+
+
+def _is_materialized(expr: Array) -> bool:
+    """
+    Returns true if an array is materialized. An array is considered to be
+    materialized if it is either a :class:`pytato.array.InputArgumentBase` or
+    is tagged with :class:`pytato.tags.ImplStored`.
+    """
+    return (isinstance(expr, InputArgumentBase)
+            or bool(expr.tags_of_type(ImplStored)))
+
+
+def _is_trivial_slice(dim: ShapeComponent, slice_: IndexT) -> bool:
+    """
+    Returns *True* only if *slice_* indexes an entire axis of shape *dim* with
+    a step of 1.
+ """ + return (isinstance(slice_, NormalizedSlice) + and slice_.step == 1 + and are_shape_components_equal(slice_.start, 0) + and are_shape_components_equal(slice_.stop, dim)) + + +def _take_along_axis(ary: Array, iaxis: int, idxs: IndexStackT) -> Array: + """ + Returns an indexed version of *ary* with *iaxis*-th axis indexed with *idxs*. + """ + # {{{ compose the slices + + composed_slice: Union[Array, slice] = slice(0, ary.shape[iaxis], 1) + + for idx in idxs[::-1]: + if isinstance(composed_slice, slice): + if isinstance(idx, NormalizedSlice): + new_start = (composed_slice.start + + composed_slice.step * idx.start) + if composed_slice.step > 0: + new_stop = (composed_slice.start + + composed_slice.step * (idx.stop-1) + + 1) + else: + new_stop = (composed_slice.start + + composed_slice.step * (idx.stop+1) + - 1) + + new_step = composed_slice.step * idx.step + + composed_slice = slice(new_start, new_stop, new_step) + else: + assert isinstance(idx, Array) + if composed_slice.step > 0: + if (composed_slice.step == 1 + and are_shape_components_equal(composed_slice.start, 0)): + # minor optimization to emit cleaner DAGs when possible + composed_slice = idx + else: + composed_slice = (composed_slice.step * idx + + composed_slice.start) + else: + composed_slice = ((composed_slice.stop - 1) + + composed_slice.step * idx) + else: + assert isinstance(composed_slice, Array) + if isinstance(idx, NormalizedSlice): + composed_slice = composed_slice[slice(idx.start, idx.stop, idx.step)] + else: + assert isinstance(idx, Array) + composed_slice = composed_slice[idx] + + # }}} + + if (isinstance(composed_slice, slice) + and _is_trivial_slice(ary.shape[iaxis], + NormalizedSlice(composed_slice.start, + composed_slice.stop, + composed_slice.step))): + return ary + else: + return ary[(slice(None), )*iaxis + (composed_slice, )] + + +@dataclass(frozen=True) +class _BindingAxisGetterAcc: + r""" + Return type of :class:`_BindingAxisGetter` recording how a particular axis + is indexed in a 
:class:`pytato.array.IndexLambda` for a particular binding.
+    """
+
+
+@dataclass(frozen=True)
+class _InvariantAxis(_BindingAxisGetterAcc):
+    r"""
+    Records that the array :attr:`_BindingAxisGetter.bnd_name`\ 's access in a
+    :class:`~pytato.scalar_expr.ScalarExpression` is invariant along the
+    :attr:`_BindingAxisGetter.iout_axis` axis.
+    """
+
+
+@dataclass(frozen=True)
+class _BindingNotAccessed(_BindingAxisGetterAcc):
+    """
+    Records that the array, :attr:`_BindingAxisGetter.bnd_name`, is not
+    accessed in a :class:`~pytato.scalar_expr.ScalarExpression`.
+    """
+
+
+@dataclass(frozen=True)
+class _SingleAxisDependentAccess(_BindingAxisGetterAcc):
+    """
+    Records that the array's *iaxis*-th index depends on only a single output
+    axis of an :class:`~pytato.array.IndexLambda`.
+    """
+    iaxis: int
+
+
+@dataclass(frozen=True)
+class _IllegalAxisAccess(_BindingAxisGetterAcc):
+    """
+    Records that the access :attr:`_BindingAxisGetter.iout_axis` does not allow
+    reindexing without modifying :class:`pytato.array.IndexLambda.expr`.
+    """
+
+
+class _BindingAxisGetter(scalar_expr.CombineMapper):
+    """
+    Mapper that returns how the binding named :attr:`bnd_name` is dependent on
+    the index :attr:`iout_axis`.
+ """ + def __init__(self, iout_axis: int, bnd_name: str): + self.iout_axis = iout_axis + self.bnd_name = bnd_name + super().__init__() + + def combine(self, + values: Iterable[_BindingAxisGetterAcc]) -> _BindingAxisGetterAcc: + + values = list(values) # avoid running into generators + if any(isinstance(val, _IllegalAxisAccess) for val in values): + return _IllegalAxisAccess() + + axis_dependent_values = {val + for val in values + if isinstance(val, _SingleAxisDependentAccess)} + invariant_axis_values = {val + for val in values + if isinstance(val, _InvariantAxis)} + + if len(invariant_axis_values | axis_dependent_values) == 0: + return _BindingNotAccessed() + elif len(invariant_axis_values | axis_dependent_values) == 1: + combined_value, = invariant_axis_values | axis_dependent_values + return combined_value + else: + return _IllegalAxisAccess() + + def map_subscript(self, expr: prim.Subscript) -> _BindingAxisGetterAcc: + from pytato.scalar_expr import get_dependencies + + if expr.aggregate == prim.Variable(self.bnd_name): + if f"_{self.iout_axis}" not in get_dependencies(expr.index_tuple): + return _InvariantAxis() + + values = [] + for i_idx, idx in enumerate(expr.index_tuple): + if get_dependencies(idx) == frozenset([f"_{self.iout_axis}"]): + values.append(_SingleAxisDependentAccess(i_idx)) + else: + values.append(self.rec(idx)) + return self.combine(values) + else: + return self.combine([self.rec(idx) + for idx in expr.index_tuple + if self.bnd_name in get_dependencies(idx)]) + + def map_variable(self, expr: prim.Variable) -> _BindingAxisGetterAcc: + if expr.name == f"_{self.iout_axis}": + return _IllegalAxisAccess() + + return _BindingNotAccessed() + + def map_constant(self, + expr: scalar_expr.ScalarExpression) -> _BindingAxisGetterAcc: + return _BindingNotAccessed() + + map_nan = map_constant + + +def _get_iaxis_in_binding(expr: scalar_expr.ScalarExpression, + iaxis: int, + bnd_name: str) -> _BindingAxisGetterAcc: + mapper = _BindingAxisGetter(iaxis, 
bnd_name) + # type-ignore-reason: pymbolic mapper types are imprecise. + return mapper(expr) # type: ignore[no-any-return] + + +class _LegallyAxisReorderingFinder(CachedMapper[FrozenSet[int]]): + """ + Maps a :class:`pytato.array` to it's set of axes along which indirections + can be propagated. We use the following rules to get the legally + reorderable axes of an array: + + - All axes of a materialized array are reorderable. + - The i-th axis of an :class:`~pytato.array.IndexLambda` is reorderable + only if all its bindings either do not index using the i-th index, OR, + every binding has a unique axis which indexes using the i-th index and + that axis in binding is reorderable. + + These rules legally allow propagating indirections that are applied to an + index lambda's to its bindings axes without altering the index lambda's + scalar expression. + """ + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> Any: + raise ValueError("_LegallyAxisReorderingFinder is a valid operation" + " only for arrays") + + def _map_materialized(self, expr: Array) -> FrozenSet[int]: + return frozenset(range(expr.ndim)) + + map_placeholder = _map_materialized + map_data_wrapper = _map_materialized + + def _map_index_lambda_like(self, expr: IndexLambda) -> FrozenSet[int]: + if _is_materialized(expr): + return self._map_materialized(expr) + + from pytato.transform.lower_to_index_lambda import to_index_lambda + idx_lambda = to_index_lambda(expr) + legal_orderings: Set[int] = set() + rec_bindings = {name: self.rec(bnd) + for name, bnd in idx_lambda.bindings.items()} + + for idim in range(idx_lambda.ndim): + bnd_name_to_iaxis = {name: _get_iaxis_in_binding(idx_lambda.expr, + idim, + name) + for name in idx_lambda.bindings} + is_reordering_idim_legal = all( + ((isinstance(ibnd_axis, _SingleAxisDependentAccess) + and ibnd_axis.iaxis in rec_bindings[name]) + or isinstance(ibnd_axis, (_InvariantAxis, _BindingNotAccessed))) + for name, ibnd_axis in bnd_name_to_iaxis.items() + ) + 
if is_reordering_idim_legal:
+                legal_orderings.add(idim)
+
+        return frozenset(legal_orderings)
+
+    # All of these lower to an IndexLambda, so share the same analysis.
+    # (A duplicate ``map_basic_index`` assignment was removed here.)
+    map_index_lambda = _map_index_lambda_like
+    map_stack = _map_index_lambda_like
+    map_concatenate = _map_index_lambda_like
+    map_einsum = _map_index_lambda_like
+    map_roll = _map_index_lambda_like
+    map_basic_index = _map_index_lambda_like
+    map_reshape = _map_index_lambda_like
+    map_axis_permutation = _map_index_lambda_like
+
+    # {{{ advanced indexing nodes are special -> requires additional checks
+    # on the indexers such as single-axis reordering
+
+    def map_contiguous_advanced_index(self,
+                                      expr: AdvancedIndexInContiguousAxes
+                                      ) -> FrozenSet[int]:
+        if _is_materialized(expr):
+            return self._map_materialized(expr)
+
+        from pytato.utils import (get_shape_after_broadcasting,
+                                  partition)
+        legal_orderings: Set[int] = set()
+        array_legal_orderings = self.rec(expr.array)
+
+        # partition returns (pred-false part, pred-true part): advanced
+        # (array) indices first, basic (slice) indices second.
+        i_adv_indices, i_basic_indices = partition(lambda idx: isinstance(
+                                                        expr.indices[idx],
+                                                        NormalizedSlice),
+                                                   range(len(expr.indices)))
+
+        pre_basic_indices, post_basic_indices = partition(
+            lambda idx: idx > i_adv_indices[0],
+            i_basic_indices,
+        )
+        ary_indices = tuple(i_idx
+                            for i_idx, idx in enumerate(expr.indices)
+                            if isinstance(idx, Array))
+
+        # contiguity of the advanced indices: every basic index lies entirely
+        # before or entirely after the advanced block
+        assert all(i_adv_indices[0] > i_basic_idx
+                   for i_basic_idx in pre_basic_indices)
+        assert all(i_adv_indices[-1] < i_basic_idx
+                   for i_basic_idx in post_basic_indices)
+
+        adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx]
+                                                      for i_idx in i_adv_indices])
+
+        iout_axis = 0
+
+        for i_idx in pre_basic_indices:
+            if i_idx in array_legal_orderings:
+                legal_orderings.add(iout_axis)
+            iout_axis += 1
+
+        if len(adv_idx_shape) != 1 or len(ary_indices) != 1:
+            # cannot reorder these axes
+            iout_axis += len(adv_idx_shape)
+        else:
+            if ary_indices[0] in array_legal_orderings:
+                legal_orderings.add(iout_axis)
+            iout_axis += 1
+
+        for i_idx in post_basic_indices:
+            if i_idx in array_legal_orderings:
+                
legal_orderings.add(iout_axis) + iout_axis += 1 + + return frozenset(legal_orderings) + + def map_non_contiguous_advanced_index(self, + expr: AdvancedIndexInNoncontiguousAxes + ) -> FrozenSet[int]: + if _is_materialized(expr): + return self._map_materialized(expr) + + from pytato.utils import (get_shape_after_broadcasting, + partition) + legal_orderings: Set[int] = set() + array_legal_orderings = self.rec(expr.array) + + i_adv_indices, i_basic_indices = partition(lambda idx: isinstance( + expr.indices[idx], + NormalizedSlice), + range(len(expr.indices))) + + ary_indices = tuple(i_idx + for i_idx, idx in enumerate(expr.indices) + if isinstance(idx, Array)) + + adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx] + for i_idx in i_adv_indices]) + + iout_axis = 0 + + if len(adv_idx_shape) != 1 or len(ary_indices) != 1: + # cannot reorder these axes + iout_axis += len(adv_idx_shape) + else: + if ary_indices[0] in array_legal_orderings: + legal_orderings.add(iout_axis) + iout_axis += 1 + + for i_idx in i_basic_indices: + if i_idx in array_legal_orderings: + legal_orderings.add(iout_axis) + iout_axis += 1 + + return frozenset(legal_orderings) + + # }}} + + +def _get_iout_axis_to_binding_axis( + expr: Array) -> Map[Array, Map[int, int]]: + from pytato.transform.lower_to_index_lambda import to_index_lambda + idx_lambda = to_index_lambda(expr) + + result: Dict[Array, Dict[int, int]] = { + bnd: {} + for bnd in idx_lambda.bindings.values() + } + + for name, bnd in idx_lambda.bindings.items(): + for iout_axis in range(expr.ndim): + ibnd_axis = _get_iaxis_in_binding(idx_lambda.expr, iout_axis, name) + if isinstance(ibnd_axis, _SingleAxisDependentAccess): + result[bnd][iout_axis] = ibnd_axis.iaxis + + return Map({k: Map(v) for k, v in result.items()}) + + +class _IndirectionPusher(Mapper): + """ + Mapper to move the indirections in the array expression closer to the + materialized nodes of the graph. 
The logic implemented in the mapper
+    complements the implementation in :class:`_LegallyAxisReorderingFinder`.
+    """
+
+    def __init__(self) -> None:
+        self.get_reordarable_axes = _LegallyAxisReorderingFinder()
+        self._cache: Dict[Tuple[ArrayOrNames, Map[int, IndexT]],
+                          ArrayOrNames] = {}
+        super().__init__()
+
+    def rec(self,  # type: ignore[override]
+            expr: MappedT,
+            indices: Tuple[IndexT, ...]) -> MappedT:
+        # NOTE(review): callers actually pass ``Map[int, ...]`` keyed by axis
+        # (possibly sparse), for which ``len(indices) != expr.ndim`` in
+        # general — this assertion looks stale; confirm against callers.
+        assert len(indices) == expr.ndim
+        key = (expr, indices)
+        try:
+            # type-ignore-reason: parametric mapping types aren't a thing in 'typing'
+            return self._cache[key]  # type: ignore[return-value]
+        except KeyError:
+            result = Mapper.rec(self, expr, indices)
+            self._cache[key] = result
+            return result  # type: ignore[no-any-return]
+
+    def __call__(self,  # type: ignore[override]
+                 expr: MappedT,
+                 indices: Map[int, IndexT]) -> MappedT:
+        return self.rec(expr, indices)
+
+    def _map_materialized(self,
+                          expr: Array,
+                          indices: Tuple[IndexT, ...]) -> Array:
+        if all(_is_trivial_slice(dim, idx)
+               for dim, idx in zip(expr.shape, indices)):
+            return expr
+        # BUGFIX: ``expr[*indices]`` is PEP-646 subscript syntax and a
+        # SyntaxError before Python 3.11; this module otherwise supports older
+        # Pythons (cf. the more_itertools fallback), so spell it portably.
+        return expr[tuple(indices)]
+
+    def map_dict_of_named_arrays(self,
+                                 expr: DictOfNamedArrays,
+                                 *args: Any, **kwargs: Any) -> Any:
+        raise ValueError("_IndirectionPusher cannot map AbstractResultOfNamedArrays")
+
+    map_placeholder = _map_materialized
+    map_data_wrapper = _map_materialized
+
+    def map_index_lambda(self,
+                         expr: IndexLambda,
+                         indices: Tuple[IndexT, ...],
+                         ) -> Array:
+        if _is_materialized(expr):
+            # FIXME: consider moving this logic into .rec; the node copy made
+            # here may be avoidable.
+
+            # do not propagate the indexings to the bindings.
+            expr = IndexLambda(expr.expr,
+                               expr.shape,
+                               expr.dtype,
+                               Map({name: self.rec(bnd, Map())
+                                    for name, bnd in expr.bindings.items()}),
+                               expr.var_to_reduction_descr,
+                               tags=expr.tags,
+                               axes=expr.axes,)
+            return self._map_materialized(expr, indices)
+
+        # FIXME: The core index-propagation logic follows; partially-pushable
+        # cases are not yet handled (see the NotImplementedErrors below).
+ + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + + new_bindings = { + name: self.rec(bnd, + Map({iout_axis_to_bnd_axis[bnd][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items() + if iout_axis in iout_axis_to_bnd_axis[bnd]}) + ) + for name, bnd in expr.bindings.items() + } + + # {{{ compute the new shape after propagating the indirections + + iaxis_to_new_shape: Dict[int, ShapeComponent] = { + idim: axis_len + for idim, axis_len in enumerate(expr.shape) + if idim not in index_stacks + } + + for iaxis in index_stacks: + for bnd_name in expr.bindings: + ibnd_axis = _get_iaxis_in_binding(expr.expr, iaxis, bnd_name) + assert isinstance(ibnd_axis, _SingleAxisDependentAccess) + new_bnd_axis_len = new_bindings[bnd_name].shape[ibnd_axis.iaxis] + assert are_shape_components_equal( + iaxis_to_new_shape.setdefault(iaxis, new_bnd_axis_len), + new_bnd_axis_len) + + # }}} + + assert len(iaxis_to_new_shape) == expr.ndim + return IndexLambda(expr=expr.expr, + bindings=Map(new_bindings), + dtype=expr.dtype, + shape=tuple(iaxis_to_new_shape[idim] + for idim in range(expr.ndim)), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes) + + def map_basic_index(self, + expr: BasicIndex, + index_stacks: Map[int, IndexStackT]) -> Array: + + if _is_materialized(expr): + # do not propagate the indexings to the indexee. 
+ expr = BasicIndex(self.rec(expr.array, Map()), + expr.indices, + tags=expr.tags, + axes=expr.axes) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_iarray_axis: Dict[int, int] = {} + iout_axis = 0 + + for iarray_axis, idx in enumerate(expr.indices): + if isinstance(idx, NormalizedSlice): + iout_axis_to_iarray_axis[iout_axis] = iarray_axis + iout_axis += 1 + + assert iout_axis == expr.ndim + # initialize from previous indexing operations + array_index_stacks = { + iout_axis_to_iarray_axis[iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items() + } + + # indices that cannot be propagated to expr.array + unreordered_i_indices: List[int] = [] + + for iarray_axis, idx in enumerate(expr.indices): + if isinstance(idx, NormalizedSlice): + if iarray_axis in self.get_reordarable_axes(expr.array): + array_index_stacks[iarray_axis] = ( + array_index_stacks.get(iarray_axis, ()) + (idx,) + ) + else: + if not _is_trivial_slice(expr.array.shape[iarray_axis], + idx): + unreordered_i_indices.append(iarray_axis) + + if unreordered_i_indices: + raise NotImplementedError("Partially pushing the indexers is not yet" + " implemented.") + + # FIXME: Think about metadata preservation?? + return self.rec(expr.array, Map(array_index_stacks)) + + def map_contiguous_advanced_index(self, + expr: AdvancedIndexInContiguousAxes, + index_stacks: Map[int, IndexStackT] + ) -> Array: + from pytato.utils import partition, get_shape_after_broadcasting + if _is_materialized(expr): + # do not propagate the indexings to the indexee. 
+ expr = AdvancedIndexInContiguousAxes(self.rec(expr.array, Map()), + expr.indices, + tags=expr.tags, + axes=expr.axes) + return self._map_materialized(expr, index_stacks) + + array_index_stacks: Dict[int, IndexStackT] = {} + unreorderable_axes: List[int] = [] + iout_axis = 0 + + i_adv_indices, i_basic_indices = partition(lambda idx: isinstance( + expr.indices[idx], + NormalizedSlice), + range(len(expr.indices))) + + pre_basic_indices, post_basic_indices = partition( + lambda idx: idx > i_adv_indices[0], + i_basic_indices, + ) + ary_indices = tuple(i_idx + for i_idx, idx in enumerate(expr.indices) + if isinstance(idx, Array)) + adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx] + for i_idx in i_adv_indices]) + + for iarray_axis in pre_basic_indices: + if iarray_axis in self.get_reordarable_axes(expr.array): + # type-ignore-reason: mypy cannot infer iarray_axis corresponds + # to a slice + array_index_stacks[iarray_axis] = ( + index_stacks.get(iout_axis, ()) + + (expr.indices[iarray_axis],)) # type: ignore[operator] + else: + # type-ignore-reason: mypy cannot infer iarray_axis corresponds + # to a slice + if not _is_trivial_slice( + expr.array.shape[iarray_axis], + expr.indices[iarray_axis] # type: ignore[arg-type] + ): + unreorderable_axes.append(iarray_axis) + iout_axis += 1 + + # type-ignore-reason: mypy cannot infer ary_indices corresponds + # to indirections + if (len(ary_indices) == 1 + and (expr.indices[ary_indices[0]].ndim # type: ignore[union-attr] + == 1) + and ary_indices[0] in self.get_reordarable_axes(expr.array)): + + array_index_stacks[ary_indices[0]] = ( + index_stacks.get(iout_axis, ()) + + (expr.indices[ary_indices[0]],)) # type: ignore[operator] + iout_axis += 1 + else: + iout_axis += len(adv_idx_shape) + unreorderable_axes.extend(ary_indices) + + for iarray_axis in post_basic_indices: + if iarray_axis in self.get_reordarable_axes(expr.array): + # type-ignore-reason: mypy cannot infer post_basic_indices + # corresponds to slices + 
array_index_stacks[iarray_axis] = ( + index_stacks.get(iout_axis, ()) + + (expr.indices[iarray_axis],)) # type: ignore[operator] + else: + if not _is_trivial_slice( + expr.array.shape[iarray_axis], + expr.indices[iarray_axis]): # type: ignore[arg-type] + unreorderable_axes.append(iarray_axis) + iout_axis += 1 + + assert iout_axis == expr.ndim + + if unreorderable_axes: + raise NotImplementedError("Partially pushing the indexers is not yet" + " implemented.") + + return self.rec(expr.array, Map(array_index_stacks)) + + def map_non_contiguous_advanced_index( + self, + expr: AdvancedIndexInNoncontiguousAxes, + index_stacks: Map[int, IndexStackT]) -> Array: + from pytato.utils import partition, get_shape_after_broadcasting + if _is_materialized(expr): + # do not propagate the indexings to the indexee. + expr = AdvancedIndexInNoncontiguousAxes(self.rec(expr.array, Map()), + expr.indices, + tags=expr.tags, + axes=expr.axes) + return self._map_materialized(expr, index_stacks) + + array_index_stacks: Dict[int, IndexStackT] = {} + unreorderable_axes: List[int] = [] + iout_axis = 0 + + i_adv_indices, i_basic_indices = partition(lambda idx: isinstance( + expr.indices[idx], + NormalizedSlice), + range(len(expr.indices))) + + ary_indices = tuple(i_idx + for i_idx, idx in enumerate(expr.indices) + if isinstance(idx, Array)) + adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx] + for i_idx in i_adv_indices]) + + if (len(ary_indices) == 1 + and (expr.indices[ary_indices[0]].ndim # type: ignore[union-attr] + == 1) + and ary_indices[0] in self.get_reordarable_axes(expr.array)): + + # type-ignore-reason: mypy cannot infer ary_indices correspond to + # indirections + array_index_stacks[ary_indices[0]] = ( + index_stacks.get(iout_axis, ()) + + (expr.indices[ary_indices[0]],)) # type: ignore[operator] + iout_axis += 1 + else: + iout_axis += len(adv_idx_shape) + unreorderable_axes.extend(ary_indices) + + for iarray_axis in i_basic_indices: + if iarray_axis in 
self.get_reordarable_axes(expr.array): + # type-ignore-reason: mypy cannot infer ary_indices correspond to + # slices + array_index_stacks[iarray_axis] = ( + index_stacks.get(iout_axis, ()) + + (expr.indices[iarray_axis],)) # type: ignore[operator] + else: + if not _is_trivial_slice( + expr.array.shape[iarray_axis], + expr.indices[iarray_axis]): # type: ignore[arg-type] + unreorderable_axes.append(iarray_axis) + iout_axis += 1 + + assert iout_axis == expr.ndim + + if unreorderable_axes: + raise NotImplementedError("Partially pushing the indexers is not yet" + " implemented.") + + return self.rec(expr.array, Map(array_index_stacks)) + + def map_stack(self, + expr: Stack, + index_stacks: Map[int, IndexStackT]) -> Array: + if _is_materialized(expr): + # do not propagate the indexings to the bindings. + expr = Stack( + arrays=tuple(self.rec(ary, Map()) for ary in expr.arrays), + axis=expr.axis, + tags=expr.tags, + axes=expr.axes, + ) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + assert expr.axis not in index_stacks + return Stack( + arrays=tuple( + self.rec(ary, + Map({iout_axis_to_bnd_axis[ary][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items()})) + for ary in expr.arrays), + axis=expr.axis, + tags=expr.tags, + axes=expr.axes, + ) + + def map_concatenate(self, + expr: Concatenate, + index_stacks: Map[int, IndexStackT]) -> Array: + if _is_materialized(expr): + # do not propagate the indexings to the bindings. 
+ expr = Concatenate( + arrays=tuple(self.rec(ary, Map()) for ary in expr.arrays), + axis=expr.axis, + tags=expr.tags, + axes=expr.axes, + ) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + assert expr.axis not in index_stacks + return Concatenate( + arrays=tuple( + self.rec(ary, + Map({iout_axis_to_bnd_axis[ary][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items()})) + for ary in expr.arrays), + axis=expr.axis, + tags=expr.tags, + axes=expr.axes, + ) + + def map_einsum(self, + expr: Einsum, + index_stacks: Map[int, IndexStackT]) -> Array: + + if _is_materialized(expr): + # do not propagate the indexings to the bindings. + expr = Einsum( + expr.access_descriptors, + args=tuple(self.rec(arg, Map()) for arg in expr.args), + index_to_access_descr=expr.index_to_access_descr, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + axes=expr.axes, + ) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + return Einsum( + expr.access_descriptors, + args=tuple( + self.rec(arg, + Map({iout_axis_to_bnd_axis[arg][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items() + if iout_axis in iout_axis_to_bnd_axis[arg]}) + ) + for arg in expr.args), + index_to_access_descr=expr.index_to_access_descr, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + axes=expr.axes, + ) + + def map_roll(self, + expr: Roll, + index_stacks: Map[int, IndexStackT]) -> Array: + if _is_materialized(expr): + # do not propagate the indexings to the bindings. 
+ expr = Roll( + self.rec(expr.array, Map()), + expr.shift, + expr.axis, + tags=expr.tags, + axes=expr.axes,) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + return Roll( + self.rec(expr.array, + Map({iout_axis_to_bnd_axis[expr.array][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items()}) + ), + expr.shift, + expr.axis, + tags=expr.tags, + axes=expr.axes,) + + def map_reshape(self, + expr: Reshape, + index_stacks: Map[int, IndexStackT]) -> Array: + + if _is_materialized(expr): + # do not propagate the indexings to the bindings. + expr = Reshape( + self.rec(expr.array, Map()), + expr.newshape, + expr.order, + tags=expr.tags, + axes=expr.axes,) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + return Reshape( + self.rec( + expr.array, + Map({iout_axis_to_bnd_axis[expr.array][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items()}) + ), + expr.newshape, + expr.order, + tags=expr.tags, + axes=expr.axes,) + + def map_axis_permutation(self, + expr: AxisPermutation, + index_stacks: Map[int, IndexStackT]) -> Array: + + if _is_materialized(expr): + # do not propagate the indexings to the bindings. + expr = AxisPermutation( + self.rec(expr.array, Map()), + expr.axis_permutation, + tags=expr.tags, + axes=expr.axes, + ) + return self._map_materialized(expr, index_stacks) + + iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr) + return AxisPermutation( + self.rec( + expr.array, + Map({iout_axis_to_bnd_axis[expr.array][iout_axis]: index_stack + for iout_axis, index_stack in index_stacks.items()}) + ), + expr.axis_permutation, + tags=expr.tags, + axes=expr.axes, + ) + + +def push_axis_indirections_towards_materialized_nodes(expr: MappedT + ) -> MappedT: + """ + Returns a copy of *expr* with the indirections propagated closer to the + materialized nodes. 
+ """ + mapper = _IndirectionPusher() + + return mapper(expr, Map()) + + +# {{{ fold indirection constants + +class _ConstantIndirectionArrayCollector(CombineMapper[FrozenSet[Array]]): + def __init__(self) -> None: + from pytato.transform import InputGatherer + super().__init__() + self.get_inputs = InputGatherer() + + def combine(self, *args: FrozenSet[Array]) -> FrozenSet[Array]: + from functools import reduce + return reduce(frozenset.union, args, frozenset()) + + def _map_input_base(self, expr: InputArgumentBase) -> FrozenSet[Array]: + return frozenset() + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + map_size_param = _map_input_base + + def _map_index_base(self, expr: IndexBase) -> FrozenSet[Array]: + rec_results: List[FrozenSet[Array]] = [] + + rec_results.append(self.rec(expr.array)) + + for idx in expr.indices: + if isinstance(idx, Array): + if any(isinstance(inp, Placeholder) + for inp in self.get_inputs(idx)): + rec_results.append(self.rec(idx)) + else: + rec_results.append(frozenset([idx])) + + return self.combine(*rec_results) + + +def fold_constant_indirections( + expr: MappedT, + evaluator: Callable[[DictOfNamedArrays], Mapping[str, DataWrapper]] +) -> MappedT: + """ + Returns a copy of *expr* with constant indirection expressions frozen. + + :arg evaluator: A callable that takes in a + :class:`~pytato.array.DictOfNamedArrays` and returns a mapping from the + name of every named array to it's corresponding evaluated array as an + instance of :class:`~pytato.array.DataWrapper`. 
+ """ + from pytools import UniqueNameGenerator + from pytato.array import make_dict_of_named_arrays + import collections.abc as abc + from pytato.transform import map_and_copy + + vng = UniqueNameGenerator() + arys_to_evaluate = _ConstantIndirectionArrayCollector()(expr) + dict_of_named_arrays = make_dict_of_named_arrays( + {vng("_pt_folded_cnst"): ary for ary in arys_to_evaluate} + ) + del arys_to_evaluate + evaluated_arys = evaluator(dict_of_named_arrays) + + if not isinstance(evaluated_arys, abc.Mapping): + raise TypeError("evaluator did not return a mapping") + + if set(evaluated_arys.keys()) != set(dict_of_named_arrays.keys()): + raise ValueError("evaluator must return a mapping with " + f"the keys: '{set(dict_of_named_arrays.keys())}'.") + + for key, ary in evaluated_arys.items(): + if not isinstance(ary, DataWrapper): + raise TypeError(f"evaluated array for '{key}' not a DataWrapper") + + before_to_after_subst = { + dict_of_named_arrays._data[name]: evaluated_ary + for name, evaluated_ary in evaluated_arys.items() + } + + def _replace_with_folded_constants(subexpr: ArrayOrNames) -> ArrayOrNames: + if isinstance(subexpr, Array): + return before_to_after_subst.get(subexpr, subexpr) + else: + return subexpr + + return map_and_copy(expr, _replace_with_folded_constants) + +# }}} + +# vim: foldmethod=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index 0f1456d9b..3997530f2 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -44,7 +44,7 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa import pytato as pt -from testlib import assert_allclose_to_numpy, get_random_pt_dag +from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref import pymbolic.primitives as p @@ -2002,6 +2002,147 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) +def _evaluator_for_indirection_folding(cl_ctx, dictofarys): + from immutables import Map + cq = cl.CommandQueue(cl_ctx) + _, 
out_dict = pt.generate_loopy(dictofarys)(cq) + return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()}) + + +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) +def test_push_indirections_0(ctx_factory, fold_constant_idxs): + from testlib import (are_all_indexees_materialized_nodes, + are_all_indexer_arrays_datawrappers) + + cl_ctx = cl.create_some_context() + rng = np.random.default_rng(0) + x_np = rng.random((10, 4)) + map1_np = rng.integers(0, 10, size=17) + map2_np = rng.integers(0, 17, size=29) + + x = pt.make_data_wrapper(x_np) + map1 = pt.make_data_wrapper(map1_np) + map2 = pt.make_data_wrapper(map2_np) + + y = 3.14 * ((42*((2*x)[map1]))[map2]) + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) + ) + + if fold_constant_idxs: + assert not are_all_indexer_arrays_datawrappers(y_transformed) + y_transformed = pt.fold_constant_indirections( + y_transformed, + lambda doa: _evaluator_for_indirection_folding(cl_ctx, + doa) + ) + assert are_all_indexer_arrays_datawrappers(y_transformed) + + auto_test_vs_ref(cl_ctx, y, y_transformed) + assert are_all_indexees_materialized_nodes(y_transformed) + + +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) +def test_push_indirections_1(ctx_factory, fold_constant_idxs): + from testlib import (are_all_indexees_materialized_nodes, + are_all_indexer_arrays_datawrappers) + + cl_ctx = cl.create_some_context() + rng = np.random.default_rng(0) + x_np = rng.random((100, 4)) + map1_np = rng.integers(0, 20, size=17) + + x = pt.make_data_wrapper(x_np) + map1 = pt.make_data_wrapper(map1_np) + + y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1]) + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) + ) + + if fold_constant_idxs: + assert not are_all_indexer_arrays_datawrappers(y_transformed) + y_transformed = 
pt.fold_constant_indirections( + y_transformed, + lambda doa: _evaluator_for_indirection_folding(cl_ctx, + doa) + ) + assert are_all_indexer_arrays_datawrappers(y_transformed) + + auto_test_vs_ref(cl_ctx, y, y_transformed) + assert are_all_indexees_materialized_nodes(y_transformed) + + +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) +def test_push_indirections_2(ctx_factory, fold_constant_idxs): + from testlib import (are_all_indexees_materialized_nodes, + are_all_indexer_arrays_datawrappers) + + cl_ctx = cl.create_some_context() + rng = np.random.default_rng(0) + x_np = rng.random((100, 10)) + map1_np = rng.integers(0, 20, size=17) + map2_np = rng.integers(0, 4, size=29) + + x = pt.make_data_wrapper(x_np) + map1 = pt.make_data_wrapper(map1_np) + map2 = pt.make_data_wrapper(map2_np) + + y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7] + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) + ) + + if fold_constant_idxs: + assert not are_all_indexer_arrays_datawrappers(y_transformed) + y_transformed = pt.fold_constant_indirections( + y_transformed, + lambda doa: _evaluator_for_indirection_folding(cl_ctx, + doa) + ) + assert are_all_indexer_arrays_datawrappers(y_transformed) + + auto_test_vs_ref(cl_ctx, y, y_transformed) + assert are_all_indexees_materialized_nodes(y_transformed) + + +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) +def test_push_indirections_3(ctx_factory, fold_constant_idxs): + from testlib import (are_all_indexees_materialized_nodes, + are_all_indexer_arrays_datawrappers) + + cl_ctx = cl.create_some_context() + rng = np.random.default_rng(0) + x_np = rng.random((10, 4)) + map1_np = rng.integers(0, 10, size=17) + map2_np = rng.integers(0, 17, size=29) + map3_np = rng.integers(0, 4, size=60) + map4_np = rng.integers(0, 60, size=22) + + x = pt.make_data_wrapper(x_np) + map1 = pt.make_data_wrapper(map1_np) + 
map2 = pt.make_data_wrapper(map2_np) + map3 = pt.make_data_wrapper(map3_np) + map4 = pt.make_data_wrapper(map4_np) + + y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4]) + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) + ) + + if fold_constant_idxs: + assert not are_all_indexer_arrays_datawrappers(y_transformed) + y_transformed = pt.fold_constant_indirections( + y_transformed, + lambda doa: _evaluator_for_indirection_folding(cl_ctx, + doa) + ) + assert are_all_indexer_arrays_datawrappers(y_transformed) + + auto_test_vs_ref(cl_ctx, y, y_transformed) + assert are_all_indexees_materialized_nodes(y_transformed) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_transform.py b/test/test_transform.py new file mode 100644 index 000000000..d5467bbc8 --- /dev/null +++ b/test/test_transform.py @@ -0,0 +1,52 @@ +__copyright__ = "Copyright (C) 2024 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pytato as pt +import numpy as np + + +def test_push_indirections_on_dg_flux_terms(): + nel = 1000 + ndof = 35 # we love 3D-P4 cells + n_intface = 2800 + n_facedof = 20 + n_face = 4 + + u1 = pt.make_placeholder("u1", (nel, ndof)) + u2 = pt.make_placeholder("u2", (nel, ndof)) + u = u1 + u2 + from_el_indices = pt.make_placeholder("from_el_indices", + n_intface, + np.int32) + dof_pick_list_indices = pt.make_placeholder("dof_pick_list_indices", + n_intface, + np.int32) + dof_pick_lists = pt.make_placeholder("dof_pick_lists", + (n_face, n_facedof), + np.int32) + result = u[from_el_indices.reshape(-1, 1), + dof_pick_lists[dof_pick_list_indices]] + transformed = pt.push_axis_indirections_towards_materialized_nodes(result) + assert transformed == (u1[from_el_indices.reshape(-1, 1), + dof_pick_lists[dof_pick_list_indices]] + + u2[from_el_indices.reshape(-1, 1), + dof_pick_lists[dof_pick_list_indices]]) diff --git a/test/testlib.py b/test/testlib.py index 5cd1342d3..2e15fc564 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -6,10 +6,11 @@ import pyopencl as cl import numpy as np import pytato as pt -from pytato.transform import Mapper +from pytato.transform import Mapper, CombineMapper from pytato.array import (Array, Placeholder, Stack, Roll, AxisPermutation, DataWrapper, Reshape, - Concatenate) + Concatenate, DictOfNamedArrays, IndexBase, + SizeParam) from pytools.tag import Tag @@ -369,4 +370,138 @@ class QuuxTag(TestlibTag): # }}} + +# {{{ utilities for test_push_indirections_* + +class _IndexeeArraysMaterializedChecker(CombineMapper[bool]): + def combine(self, *args: bool) -> bool: + return all(args) + + def map_placeholder(self, expr: Placeholder) -> bool: + return True + + def 
map_data_wrapper(self, expr: DataWrapper) -> bool: + return True + + def map_size_param(self, expr: SizeParam) -> bool: + return True + + def _map_index_base(self, expr: IndexBase) -> bool: + from pytato.transform.indirections import _is_materialized + return self.combine( + _is_materialized(expr.array) or isinstance(expr.array, IndexBase), + self.rec(expr.array) + ) + + +def are_all_indexees_materialized_nodes( + expr: Union[Array, DictOfNamedArrays]) -> bool: + """ + Returns *True* only if all indexee arrays are either materialized nodes, + OR, other indexing nodes that have materialized indexees. + """ + return _IndexeeArraysMaterializedChecker()(expr) + + +class _IndexerArrayDatawrapperChecker(CombineMapper[bool]): + def combine(self, *args: bool) -> bool: + return all(args) + + def map_placeholder(self, expr: Placeholder) -> bool: + return True + + def map_data_wrapper(self, expr: DataWrapper) -> bool: + return True + + def map_size_param(self, expr: SizeParam) -> bool: + return True + + def _map_index_base(self, expr: IndexBase) -> bool: + return self.combine( + *[isinstance(idx, DataWrapper) + for idx in expr.indices + if isinstance(idx, Array)], + super()._map_index_base(expr), + ) + + +def are_all_indexer_arrays_datawrappers( + expr: Union[Array, DictOfNamedArrays]) -> bool: + """ + Returns *True* only if all indexer arrays are instances of + :class:`~pytato.array.DataWrapper`. + """ + return _IndexerArrayDatawrapperChecker()(expr) + +# }}} + + +# {{{ auto_test_vs_ref + +class AutoTestFailureException(RuntimeError): + """ + Raised by :func:`auto_test_vs_ref` when the expressions do NOT match. 
+ """ + + +def auto_test_vs_ref(cl_ctx: "cl.Context", + actual: Union[Array, DictOfNamedArrays], + desired: Union[Array, DictOfNamedArrays], + *, + rtol: float = 1e-07, + atol: float = 0) -> None: + import pyopencl.array as cla + import loopy as lp + from pytato.transform import InputGatherer + + if isinstance(desired, Array): + if not isinstance(actual, Array): + raise AutoTestFailureException("'actual' is not an 'Array'") + + desired = pt.make_dict_of_named_arrays({"_pt_out": desired}) + actual = pt.make_dict_of_named_arrays({"_pt_out": actual}) + else: + assert isinstance(desired, DictOfNamedArrays) + if not isinstance(actual, DictOfNamedArrays): + raise AutoTestFailureException("'actual' is not" + " a 'DictOfNamedArrays'") + + cq = cl.CommandQueue(cl_ctx) + + if (any(isinstance(inp, Placeholder) + for inp in InputGatherer()(actual)) + or any(isinstance(inp, Placeholder) + for inp in InputGatherer()(desired))): + raise NotImplementedError("Expression graphs with placeholders not" + " yet supported in auto_test_vs_ref.") + + actual_prg = pt.generate_loopy(actual, options=lp.Options(return_dict=True)) + desired_prg = pt.generate_loopy(desired, options=lp.Options(return_dict=True)) + + _, actual_out_dict = actual_prg(cq) + _, desired_out_dict = desired_prg(cq) + + if set(actual_out_dict) != set(desired_out_dict): + raise AutoTestFailureException( + "Different outputs obtained from the 2 expressions. " + f" '{set(actual_out_dict.keys())}' vs '{set(desired_out_dict.keys())}'" + ) + + for output_name, desired_out in desired_out_dict.items(): + actual_out = actual_out_dict[output_name] + + if isinstance(desired_out, cla.Array): + desired_out = desired_out.get() + if isinstance(actual_out, cla.Array): + actual_out = actual_out.get() + + try: + np.testing.assert_allclose(actual_out, desired_out, + rtol=rtol, atol=atol) + except AssertionError as e: + raise AutoTestFailureException( + f"While comparing '{output_name}': \n{e.args[0]}") + +# }}} + # vim: foldmethod=marker