Skip to content

Commit 755e39a

Browse files
alexfiklinducer
authored andcommitted
mypy: fix errors from pytato typing improvements
1 parent fd59c90 commit 755e39a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

arraycontext/impl/pytato/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"""
2424

2525
from collections.abc import Mapping
26-
from typing import TYPE_CHECKING, Any
26+
from typing import TYPE_CHECKING, Any, cast
2727

2828
from pytato.array import (
2929
AbstractResultWithNamedArrays,
@@ -71,7 +71,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
7171
self.bound_arguments[name] = expr.data
7272
return make_placeholder(
7373
name=name,
74-
shape=tuple(self.rec(s) if isinstance(s, Array) else s
74+
shape=tuple(cast(Array, self.rec(s)) if isinstance(s, Array) else s
7575
for s in expr.shape),
7676
dtype=expr.dtype,
7777
axes=expr.axes,
@@ -87,7 +87,7 @@ def map_placeholder(self, expr: Placeholder) -> Array:
8787

8888
def _normalize_pt_expr(
8989
expr: DictOfNamedArrays
90-
) -> tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]:
90+
) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]:
9191
"""
9292
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
9393
normalized form of *expr*, with all instances of

0 commit comments

Comments
 (0)