3333from typing import TYPE_CHECKING , Generic , TypeVar , cast
3434
3535import numpy as np
36- from immutabledict import immutabledict
36+ from constantdict import constantdict
3737
3838import pytato as pt
3939
6262def _get_arg_id_to_arg (
6363 args : tuple [ArrayOrContainerOrScalar | None , ...],
6464 kwargs : Mapping [str , ArrayOrContainerOrScalar | None ]
65- ) -> immutabledict [tuple [SerializationKey , ...], pt .Array ]:
65+ ) -> constantdict [tuple [SerializationKey , ...], pt .Array ]:
6666 """
6767 Helper for :meth:`OutlinedCall.__call__`. Extracts mappings from argument id
6868 to argument values. See
6969 :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
7070 representation.
7171 """
72- arg_id_to_arg : dict [tuple [SerializationKey , ...], object ] = {}
72+ arg_id_to_arg : dict [tuple [SerializationKey , ...], pt . Array ] = {}
7373
7474 for kw , arg in itertools .chain (enumerate (args ),
7575 kwargs .items ()):
@@ -86,6 +86,7 @@ def id_collector(
8686 if is_scalar_like (ary ):
8787 pass
8888 else :
89+ assert isinstance (ary , pt .Array )
8990 arg_id = (kw , * keys ) # noqa: B023
9091 arg_id_to_arg [arg_id ] = ary
9192 return ary
@@ -99,7 +100,7 @@ def id_collector(
99100 " either a scalar, pt.Array or an array container. Got"
100101 f" '{ arg } '." )
101102
102- return immutabledict (arg_id_to_arg )
103+ return constantdict (arg_id_to_arg )
103104
104105
105106def _get_input_arg_id_str (
@@ -118,14 +119,14 @@ def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
118119def _get_arg_id_to_placeholder (
119120 arg_id_to_arg : Mapping [tuple [SerializationKey , ...], pt .Array ],
120121 prefix : str | None = None
121- ) -> immutabledict [tuple [SerializationKey , ...], pt .Placeholder ]:
122+ ) -> constantdict [tuple [SerializationKey , ...], pt .Placeholder ]:
122123 """
123124 Helper for :meth:`OutlinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
124125 for each argument in *arg_id_to_arg*. See
125126 :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
126127 representation.
127128 """
128- return immutabledict ({
129+ return constantdict ({
129130 arg_id : pt .make_placeholder (
130131 _get_input_arg_id_str (arg_id , prefix = prefix ),
131132 arg .shape ,
@@ -174,31 +175,32 @@ def _rec_to_placeholder(
174175
175176
176177def _unpack_output (
177- output : ArrayOrContainerOrScalar ) -> immutabledict [str , pt .Array ]:
178+ output : ArrayOrContainerOrScalar ) -> constantdict [str , pt .Array ]:
178179 """Unpack any array containers in *output*."""
179180 if isinstance (output , pt .Array ):
180- return immutabledict ({"_" : output })
181+ return constantdict ({"_" : output })
181182 elif is_array_container_type (output .__class__ ):
182- unpacked_output = {}
183+ unpacked_output : dict [ str , pt . Array ] = {}
183184
184185 def _unpack_container (
185186 key : tuple [SerializationKey , ...],
186187 ary : ArrayOrScalar
187188 ) -> ArrayOrScalar :
189+ assert isinstance (ary , pt .Array )
188190 key_str = _get_output_arg_id_str (key )
189191 unpacked_output [key_str ] = ary
190192 return ary
191193
192194 rec_keyed_map_array_container (_unpack_container , output )
193195
194- return immutabledict (unpacked_output )
196+ return constantdict (unpacked_output )
195197 else :
196198 raise NotImplementedError (type (output ))
197199
198200
199201def _pack_output (
200202 output_template : ArrayOrContainerOrScalar ,
201- unpacked_output : pt .Array | immutabledict [str , pt .Array ]
203+ unpacked_output : pt .Array | constantdict [str , pt .Array ]
202204 ) -> ArrayOrContainerOrScalar :
203205 """
204206 Pack *unpacked_output* into array containers according to *output_template*.
@@ -207,12 +209,12 @@ def _pack_output(
207209 assert isinstance (unpacked_output , pt .Array )
208210 return unpacked_output
209211 elif is_array_container_type (output_template .__class__ ):
210- assert isinstance (unpacked_output , immutabledict )
212+ assert isinstance (unpacked_output , constantdict )
211213
212214 def _pack_into_container (
213215 key : tuple [SerializationKey , ...],
214216 ary : ArrayOrScalar # pyright: ignore[reportUnusedParameter]
215- ) -> ArrayOrScalar :
217+ ) -> pt . Array :
216218 key_str = _get_output_arg_id_str (key )
217219 return unpacked_output [key_str ]
218220
@@ -287,13 +289,13 @@ def __call__(self,
287289 func_def = pt .function .FunctionDefinition (
288290 parameters = frozenset (call_bindings .keys ()),
289291 return_type = ret_type ,
290- returns = immutabledict (unpacked_output ._data ),
292+ returns = constantdict (unpacked_output ._data ),
291293 tags = self .tags ,
292294 )
293295
294296 call_site_output = func_def (** call_bindings )
295297
296- assert isinstance (call_site_output , pt .Array | immutabledict )
298+ assert isinstance (call_site_output , pt .Array | constantdict )
297299 return _pack_output (output , call_site_output )
298300
299301# vim: foldmethod=marker
0 commit comments