6969
7070from arraycontext .context import ArrayContext
7171from arraycontext .container import (
72- ContainerT , ArrayOrContainerT ,
72+ ContainerT , ArrayOrContainerT , NotAnArrayContainerError ,
7373 serialize_container , deserialize_container )
7474
7575
@@ -93,7 +93,7 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
9393
9494 try :
9595 iterable = serialize_container (_ary )
96- except TypeError :
96+ except NotAnArrayContainerError :
9797 return f (_ary )
9898 else :
9999 return deserialize_container (_ary , [
@@ -127,7 +127,7 @@ def rec(*_args: Any) -> Any:
127127
128128 try :
129129 iterable_template = serialize_container (template_ary )
130- except TypeError :
130+ except NotAnArrayContainerError :
131131 return f (* _args )
132132 else :
133133 pass
@@ -170,7 +170,7 @@ def rec(*_args: Any) -> Any:
170170 # FIXME: this will serialize again once `rec` is called, which is
171171 # not great, but it doesn't seem like there's a good way to avoid it
172172 _ = serialize_container (arg )
173- except TypeError :
173+ except NotAnArrayContainerError :
174174 pass
175175 else :
176176 container_indices .append (i )
@@ -231,7 +231,7 @@ def map_array_container(
231231 """
232232 try :
233233 iterable = serialize_container (ary )
234- except TypeError :
234+ except NotAnArrayContainerError :
235235 return f (ary )
236236 else :
237237 return deserialize_container (ary , [
@@ -316,7 +316,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any],
316316 """
317317 try :
318318 iterable = serialize_container (ary )
319- except TypeError :
319+ except NotAnArrayContainerError :
320320 raise ValueError (
321321 f"Non-array container type has no key: { type (ary ).__name__ } " )
322322 else :
@@ -338,7 +338,7 @@ def rec(keys: Tuple[Union[str, int], ...],
338338 _ary : ArrayOrContainerT ) -> ArrayOrContainerT :
339339 try :
340340 iterable = serialize_container (_ary )
341- except TypeError :
341+ except NotAnArrayContainerError :
342342 return f (keys , _ary )
343343 else :
344344 return deserialize_container (_ary , [
@@ -367,7 +367,7 @@ def map_reduce_array_container(
367367 """
368368 try :
369369 iterable = serialize_container (ary )
370- except TypeError :
370+ except NotAnArrayContainerError :
371371 return map_func (ary )
372372 else :
373373 return reduce_func ([
@@ -442,7 +442,7 @@ def rec_map_reduce_array_container(
442442 def rec (_ary : ArrayOrContainerT ) -> ArrayOrContainerT :
443443 try :
444444 iterable = serialize_container (_ary )
445- except TypeError :
445+ except NotAnArrayContainerError :
446446 return map_func (_ary )
447447 else :
448448 return reduce_func ([
@@ -501,7 +501,7 @@ def freeze(
501501 """
502502 try :
503503 iterable = serialize_container (ary )
504- except TypeError :
504+ except NotAnArrayContainerError :
505505 if actx is None :
506506 raise TypeError (
507507 f"cannot freeze arrays of type { type (ary ).__name__ } "
@@ -538,7 +538,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
538538 """
539539 try :
540540 iterable = serialize_container (ary )
541- except TypeError :
541+ except NotAnArrayContainerError :
542542 return actx .thaw (ary )
543543 else :
544544 return deserialize_container (ary , [
@@ -567,7 +567,7 @@ def _flatten(subary: ArrayOrContainerT) -> None:
567567
568568 try :
569569 iterable = serialize_container (subary )
570- except TypeError :
570+ except NotAnArrayContainerError :
571571 if common_dtype is None :
572572 common_dtype = subary .dtype
573573
@@ -618,7 +618,7 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
618618
619619 try :
620620 iterable = serialize_container (template_subary )
621- except TypeError :
621+ except NotAnArrayContainerError :
622622 if (offset + template_subary .size ) > ary .size :
623623 raise ValueError ("'template' and 'ary' sizes do not match: "
624624 "'template' is too large" )
@@ -682,9 +682,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any:
682682 The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
683683 """
684684 def _from_numpy_with_check (subary : Any ) -> Any :
685- if np .isscalar (subary ):
686- return subary
687- elif isinstance (subary , np .ndarray ):
685+ if isinstance (subary , np .ndarray ) or np .isscalar (subary ):
688686 return actx .from_numpy (subary )
689687 else :
690688 raise TypeError (f"array is not an ndarray: '{ type (subary ).__name__ } '" )
@@ -699,9 +697,7 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any:
699697 The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`.
700698 """
701699 def _to_numpy_with_check (subary : Any ) -> Any :
702- if np .isscalar (subary ):
703- return subary
704- elif isinstance (subary , actx .array_types ):
700+ if isinstance (subary , actx .array_types ) or np .isscalar (subary ):
705701 return actx .to_numpy (subary )
706702 else :
707703 raise TypeError (
@@ -734,7 +730,7 @@ def outer(a: Any, b: Any) -> Any:
734730 def treat_as_scalar (x : Any ) -> bool :
735731 try :
736732 serialize_container (x )
737- except TypeError :
733+ except NotAnArrayContainerError :
738734 return True
739735 else :
740736 return (
0 commit comments