Skip to content

Commit 3f234cc

Browse files
committed
Misc typing fixes
1 parent d525bd4 commit 3f234cc

File tree

3 files changed

+67
-42
lines changed

3 files changed

+67
-42
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,11 @@ def __init__(
166166
"""
167167
super().__init__()
168168

169-
self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
169+
self._freeze_prg_cache: dict[
170+
pt.AbstractResultWithNamedArrays, lp.TranslationUnit] = {}
170171
self._dag_transform_cache: dict[
171-
pt.DictOfNamedArrays,
172-
tuple[pt.DictOfNamedArrays, str]] = {}
172+
pt.AbstractResultWithNamedArrays,
173+
tuple[pt.AbstractResultWithNamedArrays, str]] = {}
173174

174175
if compile_trace_callback is None:
175176
def _compile_trace_callback(what, stage, ir):
@@ -229,8 +230,8 @@ def _tag_axis(ary: ArrayOrScalar) -> ArrayOrScalar:
229230

230231
# {{{ compilation
231232

232-
def transform_dag(self, dag: pytato.DictOfNamedArrays
233-
) -> pytato.DictOfNamedArrays:
233+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
234+
) -> pytato.AbstractResultWithNamedArrays:
234235
"""
235236
Returns a transformed version of *dag*. Sub-classes are supposed to
236237
override this method to implement context-specific transformations on
@@ -635,12 +636,10 @@ def _to_frozen(
635636
pt.make_dict_of_named_arrays(key_to_pt_arrays))
636637

637638
# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
638-
pt_dict_of_named_arrays = pt.tag_all_calls_to_be_inlined(
639-
pt_dict_of_named_arrays)
640-
pt_dict_of_named_arrays = pt.inline_calls(pt_dict_of_named_arrays)
639+
dag = pt.tag_all_calls_to_be_inlined(dag)
640+
dag = pt.inline_calls(dag)
641641

642-
normalized_expr, bound_arguments = _normalize_pt_expr(
643-
pt_dict_of_named_arrays)
642+
normalized_expr, bound_arguments = _normalize_pt_expr(dag)
644643

645644
try:
646645
pt_prg = self._freeze_prg_cache[normalized_expr]
@@ -786,8 +785,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
786785
from .compile import LazilyPyOpenCLCompilingFunctionCaller
787786
return LazilyPyOpenCLCompilingFunctionCaller(self, f)
788787

789-
def transform_dag(self, dag: pytato.DictOfNamedArrays
790-
) -> pytato.DictOfNamedArrays:
788+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
789+
) -> pytato.AbstractResultWithNamedArrays:
791790
import pytato as pt
792791
dag = pt.tag_all_calls_to_be_inlined(dag)
793792
dag = pt.inline_calls(dag)
@@ -986,8 +985,9 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
986985
from .compile import LazilyJAXCompilingFunctionCaller
987986
return LazilyJAXCompilingFunctionCaller(self, f)
988987

989-
def transform_dag(self, dag: pytato.DictOfNamedArrays
990-
) -> pytato.DictOfNamedArrays:
988+
@override
989+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
990+
) -> pytato.AbstractResultWithNamedArrays:
991991
import pytato as pt
992992
dag = pt.tag_all_calls_to_be_inlined(dag)
993993
dag = pt.inline_calls(dag)

arraycontext/impl/pytato/compile.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from collections.abc import Callable, Hashable, Mapping
6767

6868
import pyopencl.array as cla
69+
from pytato.array import AxesT
6970

7071
AllowedArray: TypeAlias = "pt.Array | TaggableCLArray | cla.Array"
7172
AllowedArrayTc = TypeVar("AllowedArrayTc", pt.Array, TaggableCLArray, "cla.Array")
@@ -408,12 +409,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
408409
self.actx._compile_trace_callback(
409410
prg_id, "post_transform_dag", pt_dict_of_named_arrays)
410411

411-
name_in_program_to_tags = {
412-
name: out.tags
413-
for name, out in pt_dict_of_named_arrays._data.items()}
414-
name_in_program_to_axes = {
415-
name: out.axes
416-
for name, out in pt_dict_of_named_arrays._data.items()}
412+
name_in_program_to_tags: dict[str, frozenset[Tag]] = {}
413+
name_in_program_to_axes: dict[str, AxesT] = {}
414+
if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays):
415+
name_in_program_to_tags.update({
416+
name: out.tags
417+
for name, out in pt_dict_of_named_arrays._data.items()})
418+
419+
name_in_program_to_axes.update({
420+
name: out.axes
421+
for name, out in pt_dict_of_named_arrays._data.items()})
417422

418423
self.actx._compile_trace_callback(
419424
prg_id, "pre_generate_loopy", pt_dict_of_named_arrays)
@@ -505,12 +510,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
505510
self.actx._compile_trace_callback(
506511
prg_id, "post_transform_dag", pt_dict_of_named_arrays)
507512

508-
name_in_program_to_tags = {
509-
name: out.tags
510-
for name, out in pt_dict_of_named_arrays._data.items()}
511-
name_in_program_to_axes = {
512-
name: out.axes
513-
for name, out in pt_dict_of_named_arrays._data.items()}
513+
name_in_program_to_tags: dict[str, frozenset[Tag]] = {}
514+
name_in_program_to_axes: dict[str, AxesT] = {}
515+
if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays):
516+
name_in_program_to_tags.update({
517+
name: out.tags
518+
for name, out in pt_dict_of_named_arrays._data.items()})
519+
520+
name_in_program_to_axes.update({
521+
name: out.axes
522+
for name, out in pt_dict_of_named_arrays._data.items()})
514523

515524
self.actx._compile_trace_callback(
516525
prg_id, "pre_generate_jax", pt_dict_of_named_arrays)

arraycontext/impl/pytato/utils.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
^^^^^^^^^^^^^^^^^^^^^^^^^^^
1111
1212
.. autofunction:: tabulate_profiling_data
13+
14+
References
15+
^^^^^^^^^^
16+
17+
.. autoclass:: ArrayOrNamesTc
18+
19+
A constrained type variable binding to either
20+
:class:`pytato.Array` or :class:`pytato.AbstractResultWithNames`.
1321
"""
1422

1523

@@ -38,24 +46,25 @@
3846
"""
3947

4048

41-
from typing import TYPE_CHECKING, Any, cast
49+
from typing import TYPE_CHECKING, cast
50+
51+
from typing_extensions import override
4252

4353
import pytools
4454
from pytato.analysis import get_num_call_sites
4555
from pytato.array import (
46-
AbstractResultWithNamedArrays,
4756
Array,
4857
Axis as PtAxis,
58+
DataInterface,
4959
DataWrapper,
50-
DictOfNamedArrays,
5160
Placeholder,
5261
SizeParam,
5362
make_placeholder,
5463
)
55-
from pytato.function import FunctionDefinition
5664
from pytato.target.loopy import LoopyPyOpenCLTarget
5765
from pytato.transform import (
5866
ArrayOrNames,
67+
ArrayOrNamesTc,
5968
CopyMapper,
6069
TransformMapperCache,
6170
deduplicate,
@@ -69,6 +78,8 @@
6978
from collections.abc import Mapping
7079

7180
import loopy as lp
81+
from pytato import AbstractResultWithNamedArrays
82+
from pytato.function import FunctionDefinition
7283

7384
from arraycontext import ArrayContext
7485
from arraycontext.container import SerializationKey
@@ -94,10 +105,11 @@ def __init__(
94105
_cache=_cache,
95106
_function_cache=_function_cache)
96107

97-
self.bound_arguments: dict[str, Any] = {}
98-
self.vng = UniqueNameGenerator()
108+
self.bound_arguments: dict[str, DataInterface] = {}
109+
self.vng: UniqueNameGenerator = UniqueNameGenerator()
99110
self.seen_inputs: set[str] = set()
100111

112+
@override
101113
def map_data_wrapper(self, expr: DataWrapper) -> Array:
102114
if expr.name is not None:
103115
if expr.name in self.seen_inputs:
@@ -119,13 +131,16 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
119131
axes=expr.axes,
120132
tags=expr.tags)
121133

134+
@override
122135
def map_size_param(self, expr: SizeParam) -> Array:
123136
raise NotImplementedError
124137

138+
@override
125139
def map_placeholder(self, expr: Placeholder) -> Array:
126140
raise ValueError("Placeholders cannot appear in"
127141
" DatawrapperToBoundPlaceholderMapper.")
128142

143+
@override
129144
def map_function_definition(
130145
self, expr: FunctionDefinition) -> FunctionDefinition:
131146
raise ValueError("Function definitions cannot appear in"
@@ -135,8 +150,9 @@ def map_function_definition(
135150
# FIXME: This strategy doesn't work if the DAG has functions, since function
136151
# definitions can't contain non-argument placeholders
137152
def _normalize_pt_expr(
138-
expr: DictOfNamedArrays
139-
) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]:
153+
expr: AbstractResultWithNamedArrays
154+
) -> tuple[AbstractResultWithNamedArrays,
155+
Mapping[str, DataInterface]]:
140156
"""
141157
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
142158
normalized form of *expr*, with all instances of
@@ -155,7 +171,6 @@ def _normalize_pt_expr(
155171

156172
normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
157173
normalized_expr = normalize_mapper(expr)
158-
assert isinstance(normalized_expr, AbstractResultWithNamedArrays)
159174
return normalized_expr, normalize_mapper.bound_arguments
160175

161176

@@ -172,7 +187,7 @@ def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]:
172187
class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
173188
def __init__(self, limit_arg_size_nbytes: int) -> None:
174189
super().__init__()
175-
self.limit_arg_size_nbytes = limit_arg_size_nbytes
190+
self.limit_arg_size_nbytes: int = limit_arg_size_nbytes
176191

177192
@memoize_method
178193
def get_loopy_target(self) -> lp.PyOpenCLTarget:
@@ -191,8 +206,9 @@ class TransferFromNumpyMapper(CopyMapper):
191206
"""
192207
def __init__(self, actx: ArrayContext) -> None:
193208
super().__init__()
194-
self.actx = actx
209+
self.actx: ArrayContext = actx
195210

211+
@override
196212
def map_data_wrapper(self, expr: DataWrapper) -> Array:
197213
import numpy as np
198214

@@ -223,8 +239,9 @@ class TransferToNumpyMapper(CopyMapper):
223239
"""
224240
def __init__(self, actx: ArrayContext) -> None:
225241
super().__init__()
226-
self.actx = actx
242+
self.actx: ArrayContext = actx
227243

244+
@override
228245
def map_data_wrapper(self, expr: DataWrapper) -> Array:
229246
import numpy as np
230247

@@ -244,15 +261,15 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
244261
non_equality_tags=expr.non_equality_tags)
245262

246263

247-
def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
264+
def transfer_from_numpy(expr: ArrayOrNamesTc, actx: ArrayContext) -> ArrayOrNamesTc:
248265
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
249266
instances to be device arrays, using
250267
:meth:`~arraycontext.ArrayContext.from_numpy`.
251268
"""
252269
return TransferFromNumpyMapper(actx)(expr)
253270

254271

255-
def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
272+
def transfer_to_numpy(expr: ArrayOrNamesTc, actx: ArrayContext) -> ArrayOrNamesTc:
256273
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
257274
instances to be :class:`numpy.ndarray` instances, using
258275
:meth:`~arraycontext.ArrayContext.to_numpy`.
@@ -285,8 +302,7 @@ def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table:
285302

286303
t_sum = sum(times)
287304
t_avg = t_sum / num_calls
288-
if t_sum is not None:
289-
total_time += t_sum
305+
total_time += t_sum
290306

291307
tbl.add_row((kernel_name, num_calls, f"{t_sum:{g}}", f"{t_avg:{g}}"))
292308

0 commit comments

Comments
 (0)