7070THE SOFTWARE.
7171"""
7272
73- from collections .abc import Callable , Iterable
7473from functools import partial , singledispatch , update_wrapper
75- from typing import Any , cast
74+ from typing import TYPE_CHECKING , Any , cast
7675from warnings import warn
7776
7877import numpy as np
8786 get_container_context_recursively_opt ,
8887 serialize_container ,
8988)
90- from arraycontext .context import (
91- Array ,
92- ArrayContext ,
93- ArrayOrContainer ,
94- ArrayOrContainerOrScalar ,
95- ArrayOrContainerT ,
96- ScalarLike ,
97- )
89+
90+
91+ if TYPE_CHECKING :
92+ from collections .abc import Callable , Iterable
93+
94+ from arraycontext .context import (
95+ Array ,
96+ ArrayContext ,
97+ ArrayOrContainer ,
98+ ArrayOrContainerOrScalar ,
99+ ArrayOrContainerT ,
100+ ScalarLike ,
101+ )
98102
99103
100104# {{{ array container traversal helpers
@@ -414,7 +418,7 @@ def rec(keys: tuple[SerializationKey, ...],
414418 try :
415419 iterable = serialize_container (ary_ )
416420 except NotAnArrayContainerError :
417- return cast (ArrayOrContainer , f (keys , cast (Array , ary_ )))
421+ return cast (" ArrayOrContainer" , f (keys , cast (" Array" , ary_ )))
418422 else :
419423 return deserialize_container (ary_ , [
420424 (key , rec ((* keys , key ), subary )) for key , subary in iterable
@@ -699,7 +703,7 @@ def _flatten(subary: ArrayOrContainer) -> list[Array]:
699703 try :
700704 iterable = serialize_container (subary )
701705 except NotAnArrayContainerError :
702- subary_c = cast (Array , subary )
706+ subary_c = cast (" Array" , subary )
703707
704708 if common_dtype is None :
705709 common_dtype = subary_c .dtype
@@ -786,7 +790,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
786790 try :
787791 iterable = serialize_container (template_subary )
788792 except NotAnArrayContainerError :
789- template_subary_c = cast (Array , template_subary )
793+ template_subary_c = cast (" Array" , template_subary )
790794
791795 # {{{ validate subary
792796
@@ -877,7 +881,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
877881 raise ValueError ("'template' and 'ary' sizes do not match: "
878882 "'ary' is too large" )
879883
880- return cast (ArrayOrContainerT , result )
884+ return cast (" ArrayOrContainerT" , result )
881885
882886
883887def flat_size_and_dtype (
@@ -895,7 +899,7 @@ def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
895899 try :
896900 iterable = serialize_container (subary )
897901 except NotAnArrayContainerError :
898- subary_c = cast (Array , subary )
902+ subary_c = cast (" Array" , subary )
899903
900904 if common_dtype is None :
901905 common_dtype = subary_c .dtype
0 commit comments