1010^^^^^^^^^^^^^^^^^^^^^^^^^^^
1111
1212.. autofunction:: tabulate_profiling_data
13+
14+ References
15+ ^^^^^^^^^^
16+
17+ .. autoclass:: ArrayOrNamesTc
18+
19+ A constrained type variable binding to either
20+ :class:`pytato.Array` or :class:`pytato.AbstractResultWithNames`.
1321"""
1422
1523
3846"""
3947
4048
41- from typing import TYPE_CHECKING , Any , cast
49+ from typing import TYPE_CHECKING , cast
50+
51+ from typing_extensions import override
4252
4353import pytools
4454from pytato .analysis import get_num_call_sites
4555from pytato .array import (
46- AbstractResultWithNamedArrays ,
4756 Array ,
4857 Axis as PtAxis ,
58+ DataInterface ,
4959 DataWrapper ,
50- DictOfNamedArrays ,
5160 Placeholder ,
5261 SizeParam ,
5362 make_placeholder ,
5463)
55- from pytato .function import FunctionDefinition
5664from pytato .target .loopy import LoopyPyOpenCLTarget
5765from pytato .transform import (
5866 ArrayOrNames ,
67+ ArrayOrNamesTc ,
5968 CopyMapper ,
6069 TransformMapperCache ,
6170 deduplicate ,
6978 from collections .abc import Mapping
7079
7180 import loopy as lp
81+ from pytato import AbstractResultWithNamedArrays
82+ from pytato .function import FunctionDefinition
7283
7384 from arraycontext import ArrayContext
7485 from arraycontext .container import SerializationKey
@@ -94,10 +105,11 @@ def __init__(
94105 _cache = _cache ,
95106 _function_cache = _function_cache )
96107
97- self .bound_arguments : dict [str , Any ] = {}
98- self .vng = UniqueNameGenerator ()
108+ self .bound_arguments : dict [str , DataInterface ] = {}
109+ self .vng : UniqueNameGenerator = UniqueNameGenerator ()
99110 self .seen_inputs : set [str ] = set ()
100111
112+ @override
101113 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
102114 if expr .name is not None :
103115 if expr .name in self .seen_inputs :
@@ -119,13 +131,16 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
119131 axes = expr .axes ,
120132 tags = expr .tags )
121133
134+ @override
122135 def map_size_param (self , expr : SizeParam ) -> Array :
123136 raise NotImplementedError
124137
138+ @override
125139 def map_placeholder (self , expr : Placeholder ) -> Array :
126140 raise ValueError ("Placeholders cannot appear in"
127141 " DatawrapperToBoundPlaceholderMapper." )
128142
143+ @override
129144 def map_function_definition (
130145 self , expr : FunctionDefinition ) -> FunctionDefinition :
131146 raise ValueError ("Function definitions cannot appear in"
@@ -135,8 +150,9 @@ def map_function_definition(
135150# FIXME: This strategy doesn't work if the DAG has functions, since function
136151# definitions can't contain non-argument placeholders
137152def _normalize_pt_expr (
138- expr : DictOfNamedArrays
139- ) -> tuple [Array | AbstractResultWithNamedArrays , Mapping [str , Any ]]:
153+ expr : AbstractResultWithNamedArrays
154+ ) -> tuple [AbstractResultWithNamedArrays ,
155+ Mapping [str , DataInterface ]]:
140156 """
141157 Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
142158 normalized form of *expr*, with all instances of
@@ -155,7 +171,6 @@ def _normalize_pt_expr(
155171
156172 normalize_mapper = _DatawrapperToBoundPlaceholderMapper ()
157173 normalized_expr = normalize_mapper (expr )
158- assert isinstance (normalized_expr , AbstractResultWithNamedArrays )
159174 return normalized_expr , normalize_mapper .bound_arguments
160175
161176
@@ -172,7 +187,7 @@ def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]:
172187class ArgSizeLimitingPytatoLoopyPyOpenCLTarget (LoopyPyOpenCLTarget ):
173188 def __init__ (self , limit_arg_size_nbytes : int ) -> None :
174189 super ().__init__ ()
175- self .limit_arg_size_nbytes = limit_arg_size_nbytes
190+ self .limit_arg_size_nbytes : int = limit_arg_size_nbytes
176191
177192 @memoize_method
178193 def get_loopy_target (self ) -> lp .PyOpenCLTarget :
@@ -191,8 +206,9 @@ class TransferFromNumpyMapper(CopyMapper):
191206 """
192207 def __init__ (self , actx : ArrayContext ) -> None :
193208 super ().__init__ ()
194- self .actx = actx
209+ self .actx : ArrayContext = actx
195210
211+ @override
196212 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
197213 import numpy as np
198214
@@ -223,8 +239,9 @@ class TransferToNumpyMapper(CopyMapper):
223239 """
224240 def __init__ (self , actx : ArrayContext ) -> None :
225241 super ().__init__ ()
226- self .actx = actx
242+ self .actx : ArrayContext = actx
227243
244+ @override
228245 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
229246 import numpy as np
230247
@@ -244,15 +261,15 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
244261 non_equality_tags = expr .non_equality_tags )
245262
246263
247- def transfer_from_numpy (expr : ArrayOrNames , actx : ArrayContext ) -> ArrayOrNames :
264+ def transfer_from_numpy (expr : ArrayOrNamesTc , actx : ArrayContext ) -> ArrayOrNamesTc :
248265 """Transfer arrays contained in :class:`~pytato.array.DataWrapper`
249266 instances to be device arrays, using
250267 :meth:`~arraycontext.ArrayContext.from_numpy`.
251268 """
252269 return TransferFromNumpyMapper (actx )(expr )
253270
254271
255- def transfer_to_numpy (expr : ArrayOrNames , actx : ArrayContext ) -> ArrayOrNames :
272+ def transfer_to_numpy (expr : ArrayOrNamesTc , actx : ArrayContext ) -> ArrayOrNamesTc :
256273 """Transfer arrays contained in :class:`~pytato.array.DataWrapper`
257274 instances to be :class:`numpy.ndarray` instances, using
258275 :meth:`~arraycontext.ArrayContext.to_numpy`.
@@ -285,8 +302,7 @@ def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table:
285302
286303 t_sum = sum (times )
287304 t_avg = t_sum / num_calls
288- if t_sum is not None :
289- total_time += t_sum
305+ total_time += t_sum
290306
291307 tbl .add_row ((kernel_name , num_calls , f"{ t_sum :{g }} " , f"{ t_avg :{g }} " ))
292308
0 commit comments