Skip to content

Commit cdbb9cb

Browse files
majosminducer
authored andcommitted
use constantdict instead of immutabledict
1 parent 60fa10d commit cdbb9cb

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, cast
4040

4141
import numpy as np
42-
from immutabledict import immutabledict
42+
from constantdict import constantdict
4343

4444
import pytato as pt
4545
from pytools import ProcessLogger, to_identifier
@@ -167,7 +167,7 @@ def id_collector(keys, ary):
167167
" either a scalar, pt.Array or an array container. Got"
168168
f" '{arg}'.")
169169

170-
return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr)
170+
return constantdict(arg_id_to_arg), constantdict(arg_id_to_descr)
171171

172172

173173
def _to_input_for_compiled(

arraycontext/impl/pytato/outline.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from typing import TYPE_CHECKING, Generic, TypeVar, cast
3434

3535
import numpy as np
36-
from immutabledict import immutabledict
36+
from constantdict import constantdict
3737

3838
import pytato as pt
3939

@@ -62,14 +62,14 @@
6262
def _get_arg_id_to_arg(
6363
args: tuple[ArrayOrContainerOrScalar | None, ...],
6464
kwargs: Mapping[str, ArrayOrContainerOrScalar | None]
65-
) -> immutabledict[tuple[SerializationKey, ...], pt.Array]:
65+
) -> constantdict[tuple[SerializationKey, ...], pt.Array]:
6666
"""
6767
Helper for :meth:`OutlinedCall.__call__`. Extracts mappings from argument id
6868
to argument values. See
6969
:attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
7070
representation.
7171
"""
72-
arg_id_to_arg: dict[tuple[SerializationKey, ...], object] = {}
72+
arg_id_to_arg: dict[tuple[SerializationKey, ...], pt.Array] = {}
7373

7474
for kw, arg in itertools.chain(enumerate(args),
7575
kwargs.items()):
@@ -86,6 +86,7 @@ def id_collector(
8686
if is_scalar_like(ary):
8787
pass
8888
else:
89+
assert isinstance(ary, pt.Array)
8990
arg_id = (kw, *keys) # noqa: B023
9091
arg_id_to_arg[arg_id] = ary
9192
return ary
@@ -99,7 +100,7 @@ def id_collector(
99100
" either a scalar, pt.Array or an array container. Got"
100101
f" '{arg}'.")
101102

102-
return immutabledict(arg_id_to_arg)
103+
return constantdict(arg_id_to_arg)
103104

104105

105106
def _get_input_arg_id_str(
@@ -118,14 +119,14 @@ def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
118119
def _get_arg_id_to_placeholder(
119120
arg_id_to_arg: Mapping[tuple[SerializationKey, ...], pt.Array],
120121
prefix: str | None = None
121-
) -> immutabledict[tuple[SerializationKey, ...], pt.Placeholder]:
122+
) -> constantdict[tuple[SerializationKey, ...], pt.Placeholder]:
122123
"""
123124
Helper for :meth:`OutlinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
124125
for each argument in *arg_id_to_arg*. See
125126
:attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
126127
representation.
127128
"""
128-
return immutabledict({
129+
return constantdict({
129130
arg_id: pt.make_placeholder(
130131
_get_input_arg_id_str(arg_id, prefix=prefix),
131132
arg.shape,
@@ -174,31 +175,32 @@ def _rec_to_placeholder(
174175

175176

176177
def _unpack_output(
177-
output: ArrayOrContainerOrScalar) -> immutabledict[str, pt.Array]:
178+
output: ArrayOrContainerOrScalar) -> constantdict[str, pt.Array]:
178179
"""Unpack any array containers in *output*."""
179180
if isinstance(output, pt.Array):
180-
return immutabledict({"_": output})
181+
return constantdict({"_": output})
181182
elif is_array_container_type(output.__class__):
182-
unpacked_output = {}
183+
unpacked_output: dict[str, pt.Array] = {}
183184

184185
def _unpack_container(
185186
key: tuple[SerializationKey, ...],
186187
ary: ArrayOrScalar
187188
) -> ArrayOrScalar:
189+
assert isinstance(ary, pt.Array)
188190
key_str = _get_output_arg_id_str(key)
189191
unpacked_output[key_str] = ary
190192
return ary
191193

192194
rec_keyed_map_array_container(_unpack_container, output)
193195

194-
return immutabledict(unpacked_output)
196+
return constantdict(unpacked_output)
195197
else:
196198
raise NotImplementedError(type(output))
197199

198200

199201
def _pack_output(
200202
output_template: ArrayOrContainerOrScalar,
201-
unpacked_output: pt.Array | immutabledict[str, pt.Array]
203+
unpacked_output: pt.Array | constantdict[str, pt.Array]
202204
) -> ArrayOrContainerOrScalar:
203205
"""
204206
Pack *unpacked_output* into array containers according to *output_template*.
@@ -207,12 +209,12 @@ def _pack_output(
207209
assert isinstance(unpacked_output, pt.Array)
208210
return unpacked_output
209211
elif is_array_container_type(output_template.__class__):
210-
assert isinstance(unpacked_output, immutabledict)
212+
assert isinstance(unpacked_output, constantdict)
211213

212214
def _pack_into_container(
213215
key: tuple[SerializationKey, ...],
214216
ary: ArrayOrScalar # pyright: ignore[reportUnusedParameter]
215-
) -> ArrayOrScalar:
217+
) -> pt.Array:
216218
key_str = _get_output_arg_id_str(key)
217219
return unpacked_output[key_str]
218220

@@ -287,13 +289,13 @@ def __call__(self,
287289
func_def = pt.function.FunctionDefinition(
288290
parameters=frozenset(call_bindings.keys()),
289291
return_type=ret_type,
290-
returns=immutabledict(unpacked_output._data),
292+
returns=constantdict(unpacked_output._data),
291293
tags=self.tags,
292294
)
293295

294296
call_site_output = func_def(**call_bindings)
295297

296-
assert isinstance(call_site_output, pt.Array | immutabledict)
298+
assert isinstance(call_site_output, pt.Array | constantdict)
297299
return _pack_output(output, call_site_output)
298300

299301
# vim: foldmethod=marker

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
"Topic :: Utilities",
2828
]
2929
dependencies = [
30-
"immutabledict>=4.1",
30+
"constantdict",
3131
"numpy",
3232
"pytools>=2025.2.2",
3333
# for TypeIs

0 commit comments

Comments
 (0)