Skip to content

Commit f7d0d54

Browse files
inducer and majosm authored
Mapper typing: Array->Array, Names->Names (#610)
* Mapper typing: Array->Array, Names->Names
* more typing work
* More type fixes
* MappedT -> ArrayOrNamesT
* handle DictOfNamedArrays in materialize_with_mpms outside of the mapper
* a few more tweaks
* add deduplicate
* fix doc
* Update baseline

---------

Co-authored-by: Matthew Smith <mjsmith6@illinois.edu>
1 parent 6b45ab7 commit f7d0d54

File tree

17 files changed

+130
-546
lines changed

17 files changed

+130
-546
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 464 deletions
Large diffs are not rendered by default.

examples/advection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_advection_convergence(order, flux_type):
156156
op = AdvectionOperator(discr, c=1, flux_type=flux_type,
157157
dg_ops=dg_ops)
158158
result = op.apply(u)
159-
result = pt.transform.Deduplicator()(result)
159+
result = pt.transform.deduplicate(result)
160160

161161
prog = pt.generate_loopy(result, cl_device=queue.device)
162162

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ reportImplicitStringConcatenation = "none"
157157
reportUnnecessaryIsInstance = "none"
158158
reportUnusedCallResult = "none"
159159
reportExplicitAny = "none"
160+
reportPrivateUsage = "hint"
161+
reportUnusedParameter = "hint"
160162

161163
# This reports even cycles that are qualified by 'if TYPE_CHECKING'. Not what
162164
# we care about at this moment.
@@ -189,6 +191,8 @@ reportArgumentType = "hint"
189191
reportUnknownMemberType = "hint"
190192
reportUnknownParameterType = "hint"
191193
reportAny = "none"
194+
reportPrivateUsage = "hint"
195+
reportUnusedParameter = "hint"
192196

193197
[[tool.basedpyright.executionEnvironments]]
194198
root = "examples"
@@ -201,6 +205,8 @@ reportArgumentType = "hint"
201205
reportUnknownMemberType = "hint"
202206
reportUnknownParameterType = "hint"
203207
reportAny = "none"
208+
reportPrivateUsage = "hint"
209+
reportUnusedParameter = "hint"
204210

205211
[tool.typos.default]
206212
extend-ignore-re = [

pytato/analysis/__init__.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,11 @@ def map_named_call_result(self, expr: NamedCallResult) -> None:
211211
# }}}
212212

213213

214-
def get_nusers(outputs: Array | DictOfNamedArrays) -> Mapping[Array, int]:
214+
def get_nusers(outputs: ArrayOrNames) -> Mapping[Array, int]:
215215
"""
216-
For the DAG *outputs*, returns the mapping from each node to the number of
216+
For the DAG *outputs*, returns the mapping from each array node to the number of
217217
nodes using its value within the DAG given by *outputs*.
218218
"""
219-
from pytato.codegen import normalize_outputs
220-
outputs = normalize_outputs(outputs)
221219
nuser_collector = NUserCollector()
222220
nuser_collector(outputs)
223221
return nuser_collector.nusers
@@ -508,7 +506,7 @@ def post_visit(self, expr: Any) -> None:
508506

509507

510508
def get_node_type_counts(
511-
outputs: Array | DictOfNamedArrays,
509+
outputs: ArrayOrNames,
512510
count_duplicates: bool = False
513511
) -> dict[type[Any], int]:
514512
"""
@@ -518,17 +516,14 @@ def get_node_type_counts(
518516
Instances of `DictOfNamedArrays` are excluded from counting.
519517
"""
520518

521-
from pytato.codegen import normalize_outputs
522-
outputs = normalize_outputs(outputs)
523-
524519
ncm = NodeCountMapper(count_duplicates)
525520
ncm(outputs)
526521

527522
return ncm.expr_type_counts
528523

529524

530525
def get_num_nodes(
531-
outputs: Array | DictOfNamedArrays,
526+
outputs: ArrayOrNames,
532527
count_duplicates: bool | None = None
533528
) -> int:
534529
"""
@@ -544,9 +539,6 @@ def get_num_nodes(
544539
DeprecationWarning, stacklevel=2)
545540
count_duplicates = True
546541

547-
from pytato.codegen import normalize_outputs
548-
outputs = normalize_outputs(outputs)
549-
550542
ncm = NodeCountMapper(count_duplicates)
551543
ncm(outputs)
552544

@@ -586,14 +578,10 @@ def post_visit(self, expr: Any) -> None:
586578
self.expr_multiplicity_counts[expr] += 1
587579

588580

589-
def get_node_multiplicities(
590-
outputs: Array | DictOfNamedArrays) -> dict[Array, int]:
581+
def get_node_multiplicities(outputs: ArrayOrNames) -> dict[Array, int]:
591582
"""
592583
Returns the multiplicity per `expr`.
593584
"""
594-
from pytato.codegen import normalize_outputs
595-
outputs = normalize_outputs(outputs)
596-
597585
nmm = NodeMultiplicityMapper()
598586
nmm(outputs)
599587

@@ -640,7 +628,7 @@ def post_visit(self, expr: Any) -> None:
640628
self.count += 1
641629

642630

643-
def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:
631+
def get_num_call_sites(outputs: ArrayOrNames) -> int:
644632
"""Returns the number of nodes in DAG *outputs*."""
645633

646634
from pytato.codegen import normalize_outputs
@@ -700,7 +688,7 @@ def rec(self, expr: ArrayOrNames) -> int:
700688

701689

702690
def get_num_tags_of_type(
703-
outputs: Array | DictOfNamedArrays,
691+
outputs: ArrayOrNames,
704692
tag_types: type[pytools.tag.Tag] | Iterable[type[pytools.tag.Tag]]) -> int:
705693
"""Returns the number of nodes in DAG *outputs* that are tagged with
706694
all the tag types in *tag_types*."""

pytato/codegen.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010

1111
from __future__ import annotations
1212

13-
from typing import TypeAlias
14-
15-
from typing_extensions import TypeIs
16-
1713

1814
__copyright__ = """Copyright (C) 2020 Matt Wala"""
1915

@@ -38,15 +34,17 @@
3834
"""
3935

4036
import dataclasses
41-
from typing import TYPE_CHECKING, Any
37+
from typing import TYPE_CHECKING, Any, TypeAlias
4238

4339
from immutabledict import immutabledict
40+
from typing_extensions import TypeIs
4441

4542
import loopy as lp
4643
from pymbolic.mapper.optimize import optimize_mapper
4744
from pytools import UniqueNameGenerator
4845

4946
from pytato.array import (
47+
AbstractResultWithNamedArrays,
5048
Array,
5149
DataInterface,
5250
DataWrapper,
@@ -241,24 +239,19 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
241239

242240

243241
def normalize_outputs(
244-
result: Array | DictOfNamedArrays | dict[str, Array]
245-
) -> DictOfNamedArrays:
242+
result: ArrayOrNames | dict[str, Array]
243+
) -> AbstractResultWithNamedArrays:
246244
"""Convert outputs of a computation to the canonical form.
247245
248246
Performs a conversion to :class:`~pytato.DictOfNamedArrays` if necessary.
249247
250248
:param result: Outputs of the computation.
251249
"""
252-
if not isinstance(result, Array | DictOfNamedArrays | dict):
253-
raise TypeError("outputs of the computation should be "
254-
"either an Array or a DictOfNamedArrays")
255-
256250
if isinstance(result, Array):
257251
outputs = make_dict_of_named_arrays({"_pt_out": result})
258252
elif isinstance(result, dict):
259253
outputs = make_dict_of_named_arrays(result)
260254
else:
261-
assert isinstance(result, DictOfNamedArrays)
262255
outputs = result
263256

264257
return outputs

pytato/target/loopy/codegen.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import pytato.reductions as red
4242
import pytato.scalar_expr as scalar_expr
4343
from pytato.array import (
44+
AbstractResultWithNamedArrays,
4445
Array,
4546
DataWrapper,
4647
DictOfNamedArrays,
@@ -1036,7 +1037,7 @@ def get_initial_codegen_state(target: LoopyTarget,
10361037

10371038
# {{{ generate_loopy
10381039

1039-
def generate_loopy(result: Array | DictOfNamedArrays | dict[str, Array],
1040+
def generate_loopy(result: Array | AbstractResultWithNamedArrays | dict[str, Array],
10401041
target: LoopyTarget | None = None,
10411042
options: lp.Options | None = None,
10421043
*,
@@ -1083,7 +1084,11 @@ def generate_loopy(result: Array | DictOfNamedArrays | dict[str, Array],
10831084
"""
10841085

10851086
result_is_dict = isinstance(result, dict | DictOfNamedArrays)
1086-
orig_outputs: DictOfNamedArrays = normalize_outputs(result)
1087+
orig_outputs: AbstractResultWithNamedArrays = normalize_outputs(result)
1088+
1089+
if not isinstance(orig_outputs, DictOfNamedArrays):
1090+
raise NotImplementedError(
1091+
f"not implemented for {type(result).__name__}.")
10871092

10881093
del result
10891094

pytato/transform/__init__.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
import numpy as np
4444
from immutabledict import immutabledict
45-
from typing_extensions import Never, Self
45+
from typing_extensions import Never, Self, override
4646

4747
from pymbolic.mapper.optimize import optimize_mapper
4848

@@ -90,8 +90,11 @@
9090

9191

9292
ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays
93-
MappedT = TypeVar("MappedT",
94-
Array, AbstractResultWithNamedArrays, ArrayOrNames)
93+
ArrayOrNamesTc = TypeVar("ArrayOrNamesTc",
94+
Array, AbstractResultWithNamedArrays, DictOfNamedArrays)
95+
ArrayOrNamesOrFunctionDefTc = TypeVar("ArrayOrNamesOrFunctionDefTc",
96+
Array, AbstractResultWithNamedArrays, DictOfNamedArrays,
97+
FunctionDefinition)
9598
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
9699
R = frozenset[Array]
97100

@@ -116,6 +119,7 @@
116119
.. autoclass:: TopoSortMapper
117120
.. autoclass:: CachedMapAndCopyMapper
118121
.. autofunction:: copy_dict_of_named_arrays
122+
.. autofunction:: deduplicate
119123
.. autofunction:: get_dependencies
120124
.. autofunction:: map_and_copy
121125
.. autofunction:: materialize_with_mpms
@@ -145,9 +149,15 @@
145149
146150
.. class:: ArrayOrNames
147151
148-
.. class:: MappedT
152+
.. class:: ArrayOrNamesTc
149153
150-
A type variable representing the input type of a :class:`Mapper`.
154+
A type variable representing the input type of a :class:`Mapper`, excluding
155+
functions.
156+
157+
.. class:: ArrayOrNamesOrFunctionDefTc
158+
159+
A type variable representing the input type of a :class:`Mapper`, including
160+
functions.
151161
152162
.. class:: ResultT
153163
@@ -702,6 +712,22 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self:
702712
err_on_created_duplicate=function_cache.err_on_created_duplicate,
703713
_function_cache=function_cache)
704714

715+
@override
716+
# This overrides incompatibly on purpose, in order to convey stronger
717+
# guarantees. We're not trying to be very mapper-polymorphic, so
718+
# IMO this inconsistency is "worth it(tm)". -AK, 2025-06-16
719+
def __call__( # pyright: ignore[reportIncompatibleMethodOverride]
720+
self,
721+
expr: ArrayOrNamesOrFunctionDefTc,
722+
) -> ArrayOrNamesOrFunctionDefTc:
723+
"""Handle the mapping of *expr*."""
724+
if isinstance(expr, Array):
725+
return cast("Array", self.rec(expr))
726+
elif isinstance(expr, AbstractResultWithNamedArrays):
727+
return cast("AbstractResultWithNamedArrays", self.rec(expr))
728+
else:
729+
return self.rec_function_definition(expr)
730+
705731
# }}}
706732

707733

@@ -1116,7 +1142,7 @@ def map_named_call_result(self, expr: NamedCallResult,
11161142
# }}}
11171143

11181144

1119-
# {{{ Deduplicator
1145+
# {{{ deduplicate
11201146

11211147
class Deduplicator(CopyMapper):
11221148
"""Removes duplicate nodes from an expression."""
@@ -1135,6 +1161,20 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self:
11351161
_function_cache=cast(
11361162
"TransformMapperCache[FunctionDefinition, []]", self._function_cache))
11371163

1164+
1165+
def deduplicate(
1166+
expr: ArrayOrNamesOrFunctionDefTc
1167+
) -> ArrayOrNamesOrFunctionDefTc:
1168+
"""
1169+
Remove duplicate nodes from an expression.
1170+
1171+
.. note::
1172+
Does not remove distinct instances of data wrappers that point to the same
1173+
data (as they will not hash the same). For a utility that does that, see
1174+
:func:`deduplicate_data_wrappers`.
1175+
"""
1176+
return Deduplicator()(expr)
1177+
11381178
# }}}
11391179

11401180

@@ -2061,9 +2101,9 @@ def get_dependencies(expr: DictOfNamedArrays) -> dict[str, frozenset[Array]]:
20612101
return {name: dep_mapper(val.expr) for name, val in expr.items()}
20622102

20632103

2064-
def map_and_copy(expr: MappedT,
2104+
def map_and_copy(expr: ArrayOrNamesTc,
20652105
map_fn: Callable[[ArrayOrNames], ArrayOrNames]
2066-
) -> MappedT:
2106+
) -> ArrayOrNamesTc:
20672107
"""
20682108
Returns a copy of *expr* with every array expression reachable from *expr*
20692109
mapped via *map_fn*.
@@ -2073,10 +2113,10 @@ def map_and_copy(expr: MappedT,
20732113
Uses :class:`CachedMapAndCopyMapper` under the hood and because of its
20742114
caching nature each node is mapped exactly once.
20752115
"""
2076-
return cast("MappedT", CachedMapAndCopyMapper(map_fn)(expr))
2116+
return CachedMapAndCopyMapper(map_fn)(expr)
20772117

20782118

2079-
def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
2119+
def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
20802120
r"""
20812121
Materialize nodes in *expr* with MPMS materialization strategy.
20822122
MPMS stands for Multiple-Predecessors, Multiple-Successors.
@@ -2128,11 +2168,18 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
21282168
"""
21292169
from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers
21302170
materializer = MPMSMaterializer(get_nusers(expr))
2131-
new_data = {}
2132-
for name, ary in expr.items():
2133-
new_data[name] = materializer(ary.expr).expr
21342171

2135-
res = DictOfNamedArrays(new_data, tags=expr.tags)
2172+
if isinstance(expr, Array):
2173+
res = materializer(expr).expr
2174+
assert isinstance(res, Array)
2175+
elif isinstance(expr, DictOfNamedArrays):
2176+
res = expr.replace_if_different(
2177+
data={
2178+
name: _verify_is_array(materializer(ary).expr)
2179+
for name, ary, in expr._data.items()})
2180+
assert isinstance(res, DictOfNamedArrays)
2181+
else:
2182+
raise NotImplementedError("not implemented for {type(expr).__name__}.")
21362183

21372184
from pytato import DEBUG_ENABLED
21382185
if DEBUG_ENABLED:
@@ -2408,7 +2455,9 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self:
24082455
"TransformMapperCache[FunctionDefinition, []]", self._function_cache))
24092456

24102457

2411-
def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames:
2458+
def deduplicate_data_wrappers(
2459+
array_or_names: ArrayOrNamesOrFunctionDefTc
2460+
) -> ArrayOrNamesOrFunctionDefTc:
24122461
"""For the expression graph given as *array_or_names*, replace all
24132462
:class:`pytato.array.DataWrapper` instances containing identical data
24142463
with a single instance.

0 commit comments

Comments
 (0)