22from __future__ import annotations
33
44
5- """
5+ __doc__ = """
66.. currentmodule:: arraycontext
7+
78.. autofunction:: with_container_arithmetic
89"""
910
10- import enum
11-
1211
1312__copyright__ = """
1413Copyright (C) 2020-1 University of Illinois Board of Trustees
3433THE SOFTWARE.
3534"""
3635
36+ import enum
3737from typing import Any , Callable , Optional , Tuple , TypeVar , Union
38+ from warnings import warn
3839
3940import numpy as np
4041
@@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
99100
100101
101102def _format_binary_op_str (op_str : str ,
102- arg1 : Union [Tuple [str , ... ], str ],
103- arg2 : Union [Tuple [str , ... ], str ]) -> str :
103+ arg1 : Union [Tuple [str , str ], str ],
104+ arg2 : Union [Tuple [str , str ], str ]) -> str :
104105 if isinstance (arg1 , tuple ) and isinstance (arg2 , tuple ):
105106 import sys
106107 if sys .version_info >= (3 , 10 ):
@@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str,
127128 return op_str .format (arg1 , arg2 )
128129
129130
131+ class NumpyObjectArrayMetaclass (type ):
132+ def __instancecheck__ (cls , instance : Any ) -> bool :
133+ return isinstance (instance , np .ndarray ) and instance .dtype == object
134+
135+
136+ class NumpyObjectArray (metaclass = NumpyObjectArrayMetaclass ):
137+ pass
138+
139+
140+ class ComplainingNumpyNonObjectArrayMetaclass (type ):
141+ def __instancecheck__ (cls , instance : Any ) -> bool :
142+ if isinstance (instance , np .ndarray ) and instance .dtype != object :
143+ # Example usage site:
144+ # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149
145+ # where normal is passed in by test_lfr_flux as a 'custom-made'
146+ # numpy array of dtype float64.
147+ warn (
148+ "Broadcasting container against non-object numpy array. "
149+ "This was never documented to work and will now stop working in "
150+ "2025. Convert the array to an object array to preserve the "
151+ "current semantics." , DeprecationWarning , stacklevel = 3 )
152+ return True
153+ else :
154+ return False
155+
156+
157+ class ComplainingNumpyNonObjectArray (metaclass = ComplainingNumpyNonObjectArrayMetaclass ):
158+ pass
159+
160+
130161def with_container_arithmetic (
131162 * ,
132163 bcast_number : bool = True ,
@@ -146,22 +177,16 @@ def with_container_arithmetic(
146177
147178 :arg bcast_number: If *True*, numbers broadcast over the container
148179 (with the container as the 'outer' structure).
149- :arg _bcast_actx_array_type: If *True*, instances of base array types of the
150- container's array context are broadcasted over the container. Can be
151- *True* only if the container has *_cls_has_array_context_attr* set.
152- Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set,
153- else *False*.
154- :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over
155- the container. (with the container as the 'inner' structure)
156- :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
157- over the container. (with the container as the 'inner' structure)
158- If this is set to *True*, *bcast_obj_array* must also be *True*.
180+ :arg bcast_obj_array: If *True*, this container will be broadcast
181+ across :mod:`numpy` object arrays
182+ (with the object array as the 'outer' structure).
183+ Add :class:`numpy.ndarray` to *bcast_container_types* to achieve
184+ the 'reverse' broadcasting.
159185 :arg bcast_container_types: A sequence of container types that will broadcast
160- over this container ( with this container as the 'outer' structure) .
186+ across this container, with this container as the 'outer' structure.
161187 :class:`numpy.ndarray` is permitted to be part of this sequence to
162- indicate that, in such broadcasting situations, this container should
163- be the 'outer' structure. In this case, *bcast_obj_array*
164- (and consequently *bcast_numpy_array*) must be *False*.
188+ indicate that object arrays (and *only* object arrays) will be broadcasat.
189+ In this case, *bcast_obj_array* must be *False*.
165190 :arg arithmetic: Implement the conventional arithmetic operators, including
166191 ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
167192 :func:`abs`.
@@ -203,6 +228,17 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
203228 should nest "outside" :func:dataclass_array_container`.
204229 """
205230
231+ # Hard-won design lessons:
232+ #
233+ # - Anything that special-cases np.ndarray by type is broken by design because:
234+ # - np.ndarray is an array context array.
235+ # - numpy object arrays can be array containers.
236+ # Using NumpyObjectArray and NumpyNonObjectArray *may* be better?
237+ # They're new, so there is no operational experience with them.
238+ #
239+ # - Broadcast rules are hard to change once established, particularly
240+ # because one cannot grep for their use.
241+
206242 # {{{ handle inputs
207243
208244 if bcast_obj_array is None :
@@ -212,9 +248,8 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
212248 raise TypeError ("rel_comparison must be specified" )
213249
214250 if bcast_numpy_array :
215- from warnings import warn
216251 warn ("'bcast_numpy_array=True' is deprecated and will be unsupported"
217- " from December 2021 " , DeprecationWarning , stacklevel = 2 )
252+ " from 2025. " , DeprecationWarning , stacklevel = 2 )
218253
219254 if _bcast_actx_array_type :
220255 raise ValueError ("'bcast_numpy_array' and '_bcast_actx_array_type'"
@@ -231,7 +266,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
231266
232267 if bcast_numpy_array :
233268 def numpy_pred (name : str ) -> str :
234- return f"isinstance ({ name } , np.ndarray )"
269+ return f"is_numpy_array ({ name } )"
235270 elif bcast_obj_array :
236271 def numpy_pred (name : str ) -> str :
237272 return f"isinstance({ name } , np.ndarray) and { name } .dtype.char == 'O'"
@@ -241,12 +276,21 @@ def numpy_pred(name: str) -> str:
241276
242277 if bcast_container_types is None :
243278 bcast_container_types = ()
244- bcast_container_types_count = len (bcast_container_types )
245279
246280 if np .ndarray in bcast_container_types and bcast_obj_array :
247281 raise ValueError ("If numpy.ndarray is part of bcast_container_types, "
248282 "bcast_obj_array must be False." )
249283
284+ numpy_check_types : list [type ] = [NumpyObjectArray , ComplainingNumpyNonObjectArray ]
285+ bcast_container_types = tuple (
286+ new_ct
287+ for old_ct in bcast_container_types
288+ for new_ct in
289+ (numpy_check_types
290+ if old_ct is np .ndarray
291+ else [old_ct ])
292+ )
293+
250294 desired_op_classes = set ()
251295 if arithmetic :
252296 desired_op_classes .add (_OpClass .ARITHMETIC )
@@ -264,19 +308,24 @@ def numpy_pred(name: str) -> str:
264308 # }}}
265309
266310 def wrap (cls : Any ) -> Any :
267- cls_has_array_context_attr : bool | None = \
268- _cls_has_array_context_attr
269- bcast_actx_array_type : bool | None = \
270- _bcast_actx_array_type
311+ if not hasattr (cls , "__array_ufunc__" ):
312+ warn (f"{ cls } does not have __array_ufunc__ set. "
313+ "This will cause numpy to attempt broadcasting, in a way that "
314+ "is likely undesired. "
315+ f"To avoid this, set __array_ufunc__ = None in { cls } ." ,
316+ stacklevel = 2 )
317+
318+ cls_has_array_context_attr : bool | None = _cls_has_array_context_attr
319+ bcast_actx_array_type : bool | None = _bcast_actx_array_type
271320
272321 if cls_has_array_context_attr is None :
273322 if hasattr (cls , "array_context" ):
274323 raise TypeError (
275324 f"{ cls } has an 'array_context' attribute, but it does not "
276325 "set '_cls_has_array_context_attr' to *True* when calling "
277326 "with_container_arithmetic. This is being interpreted "
278- "as 'array_context' being permitted to fail with an exception, "
279- "which is no longer allowed. "
327+ "as '. array_context' being permitted to fail "
328+ "with an exception, which is no longer allowed. "
280329 f"If { cls .__name__ } .array_context will not fail, pass "
281330 "'_cls_has_array_context_attr=True'. "
282331 "If you do not want container arithmetic to make "
@@ -294,6 +343,30 @@ def wrap(cls: Any) -> Any:
294343 raise TypeError ("_bcast_actx_array_type can be True only if "
295344 "_cls_has_array_context_attr is set." )
296345
346+ if bcast_actx_array_type :
347+ if _bcast_actx_array_type :
348+ warn (
349+ f"Broadcasting array context array types across { cls } "
350+ "has been explicitly "
351+ "enabled. As of 2025, this will stop working. "
352+ "There is no replacement as of right now. "
353+ "See the discussion in "
354+ "https://github.com/inducer/arraycontext/pull/190. "
355+ "To opt out now (and avoid this warning), "
356+ "pass _bcast_actx_array_type=False. " ,
357+ DeprecationWarning , stacklevel = 2 )
358+ else :
359+ warn (
360+ f"Broadcasting array context array types across { cls } "
361+ "has been implicitly "
362+ "enabled. As of 2025, this will no longer work. "
363+ "There is no replacement as of right now. "
364+ "See the discussion in "
365+ "https://github.com/inducer/arraycontext/pull/190. "
366+ "To opt out now (and avoid this warning), "
367+ "pass _bcast_actx_array_type=False." ,
368+ DeprecationWarning , stacklevel = 2 )
369+
297370 if (not hasattr (cls , "_serialize_init_arrays_code" )
298371 or not hasattr (cls , "_deserialize_init_arrays_code" )):
299372 raise TypeError (f"class '{ cls .__name__ } ' must provide serialization "
@@ -304,7 +377,7 @@ def wrap(cls: Any) -> Any:
304377
305378 from pytools .codegen import CodeGenerator , Indentation
306379 gen = CodeGenerator ()
307- gen ("""
380+ gen (f """
308381 from numbers import Number
309382 import numpy as np
310383 from arraycontext import ArrayContainer
@@ -315,6 +388,24 @@ def _raise_if_actx_none(actx):
315388 raise ValueError("array containers with frozen arrays "
316389 "cannot be operated upon")
317390 return actx
391+
392+ def is_numpy_array(arg):
393+ if isinstance(arg, np.ndarray):
394+ if arg.dtype != "O":
395+ warn("Operand is a non-object numpy array, "
396+ "and the broadcasting behavior of this array container "
397+ "({ cls } ) "
398+ "is influenced by this because of its use of "
399+ "the deprecated bcast_numpy_array. This broadcasting "
400+ "behavior will change in 2025. If you would like the "
401+ "broadcasting behavior to stay the same, make sure "
402+ "to convert the passed numpy array to an "
403+ "object array.",
404+ DeprecationWarning, stacklevel=3)
405+ return True
406+ else:
407+ return False
408+
318409 """ )
319410 gen ("" )
320411
@@ -323,7 +414,7 @@ def _raise_if_actx_none(actx):
323414 gen (f"from { bct .__module__ } import { bct .__qualname__ } as _bctype{ i } " )
324415 gen ("" )
325416 outer_bcast_type_names = tuple ([
326- f"_bctype{ i } " for i in range (bcast_container_types_count )
417+ f"_bctype{ i } " for i in range (len ( bcast_container_types ) )
327418 ])
328419 if bcast_number :
329420 outer_bcast_type_names += ("Number" ,)
@@ -384,20 +475,25 @@ def {fname}(arg1):
384475
385476 continue
386477
387- # {{{ "forward" binary operators
388-
389478 zip_init_args = cls ._deserialize_init_arrays_code ("arg1" , {
390479 same_key (key_arg1 , key_arg2 ):
391480 _format_binary_op_str (op_str , expr_arg1 , expr_arg2 )
392481 for (key_arg1 , expr_arg1 ), (key_arg2 , expr_arg2 ) in zip (
393482 cls ._serialize_init_arrays_code ("arg1" ).items (),
394483 cls ._serialize_init_arrays_code ("arg2" ).items ())
395484 })
396- bcast_same_cls_init_args = cls ._deserialize_init_arrays_code ("arg1" , {
485+ bcast_init_args_arg1_is_outer = cls ._deserialize_init_arrays_code ("arg1" , {
397486 key_arg1 : _format_binary_op_str (op_str , expr_arg1 , "arg2" )
398487 for key_arg1 , expr_arg1 in
399488 cls ._serialize_init_arrays_code ("arg1" ).items ()
400489 })
490+ bcast_init_args_arg2_is_outer = cls ._deserialize_init_arrays_code ("arg2" , {
491+ key_arg2 : _format_binary_op_str (op_str , "arg1" , expr_arg2 )
492+ for key_arg2 , expr_arg2 in
493+ cls ._serialize_init_arrays_code ("arg2" ).items ()
494+ })
495+
496+ # {{{ "forward" binary operators
401497
402498 gen (f"def { fname } (arg1, arg2):" )
403499 with Indentation (gen ):
@@ -424,7 +520,7 @@ def {fname}(arg1):
424520
425521 if bcast_actx_array_type :
426522 if __debug__ :
427- bcast_actx_ary_types = (
523+ bcast_actx_ary_types : tuple [ str , ...] = (
428524 "*_raise_if_actx_none("
429525 "arg1.array_context).array_types" ,)
430526 else :
@@ -444,7 +540,19 @@ def {fname}(arg1):
444540 if isinstance(arg2,
445541 { tup_str (outer_bcast_type_names
446542 + bcast_actx_ary_types )} ):
447- return cls({ bcast_same_cls_init_args } )
543+ if __debug__:
544+ if isinstance(arg2, { tup_str (bcast_actx_ary_types )} ):
545+ warn("Broadcasting { cls } over array "
546+ f"context array type {{type(arg2)}} is deprecated "
547+ "and will no longer work in 2025. "
548+ "There is no replacement as of right now. "
549+ "See the discussion in "
550+ "https://github.com/inducer/arraycontext/"
551+ "pull/190. ",
552+ DeprecationWarning, stacklevel=2)
553+
554+ return cls({ bcast_init_args_arg1_is_outer } )
555+
448556 return NotImplemented
449557 """ )
450558 gen (f"cls.__{ dunder_name } __ = { fname } " )
@@ -456,12 +564,6 @@ def {fname}(arg1):
456564
457565 if reversible :
458566 fname = f"_{ cls .__name__ .lower ()} _r{ dunder_name } "
459- bcast_init_args = cls ._deserialize_init_arrays_code ("arg2" , {
460- key_arg2 : _format_binary_op_str (
461- op_str , "arg1" , expr_arg2 )
462- for key_arg2 , expr_arg2 in
463- cls ._serialize_init_arrays_code ("arg2" ).items ()
464- })
465567
466568 if bcast_actx_array_type :
467569 if __debug__ :
@@ -487,7 +589,21 @@ def {fname}(arg2, arg1):
487589 if isinstance(arg1,
488590 { tup_str (outer_bcast_type_names
489591 + bcast_actx_ary_types )} ):
490- return cls({ bcast_init_args } )
592+ if __debug__:
593+ if isinstance(arg1,
594+ { tup_str (bcast_actx_ary_types )} ):
595+ warn("Broadcasting { cls } over array "
596+ f"context array type {{type(arg1)}} "
597+ "is deprecated "
598+ "and will no longer work in 2025."
599+ "There is no replacement as of right now. "
600+ "See the discussion in "
601+ "https://github.com/inducer/arraycontext/"
602+ "pull/190. ",
603+ DeprecationWarning, stacklevel=2)
604+
605+ return cls({ bcast_init_args_arg2_is_outer } )
606+
491607 return NotImplemented
492608
493609 cls.__r{ dunder_name } __ = { fname } """ )
0 commit comments