Skip to content

Commit 9f05f72

Browse files
alexfiklinducer
authored andcommitted
add container tests to test_array_context_np_like
1 parent 5c9d57a commit 9f05f72

File tree

2 files changed

+103
-73
lines changed

2 files changed

+103
-73
lines changed

arraycontext/impl/jax/fake_numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ def _rec_vdot(ary1, ary2):
126126

127127
# {{{ logic functions
128128

129+
def all(self, a):
130+
return rec_map_reduce_array_container(
131+
partial(reduce, jnp.logical_and), jnp.all, a)
132+
133+
def any(self, a):
134+
return rec_map_reduce_array_container(
135+
partial(reduce, jnp.logical_or), jnp.any, a)
136+
129137
def array_equal(self, a, b):
130138
actx = self._array_context
131139

test/test_arraycontext.py

Lines changed: 95 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,84 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type:
197197
# }}}
198198

199199

200+
# {{{ nested containers
201+
202+
@with_container_arithmetic(bcast_obj_array=False,
203+
eq_comparison=False, rel_comparison=False,
204+
_cls_has_array_context_attr=True)
205+
@dataclass_array_container
206+
@dataclass(frozen=True)
207+
class MyContainer:
208+
name: str
209+
mass: Union[DOFArray, np.ndarray]
210+
momentum: np.ndarray
211+
enthalpy: Union[DOFArray, np.ndarray]
212+
213+
@property
214+
def array_context(self):
215+
if isinstance(self.mass, np.ndarray):
216+
return next(iter(self.mass)).array_context
217+
else:
218+
return self.mass.array_context
219+
220+
221+
@with_container_arithmetic(
222+
bcast_obj_array=False,
223+
bcast_container_types=(DOFArray, np.ndarray),
224+
matmul=True,
225+
rel_comparison=True,
226+
_cls_has_array_context_attr=True)
227+
@dataclass_array_container
228+
@dataclass(frozen=True)
229+
class MyContainerDOFBcast:
230+
name: str
231+
mass: Union[DOFArray, np.ndarray]
232+
momentum: np.ndarray
233+
enthalpy: Union[DOFArray, np.ndarray]
234+
235+
@property
236+
def array_context(self):
237+
if isinstance(self.mass, np.ndarray):
238+
return next(iter(self.mass)).array_context
239+
else:
240+
return self.mass.array_context
241+
242+
243+
def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
244+
from numbers import Number
245+
if isinstance(shapes, (Number, tuple)):
246+
shapes = [shapes]
247+
248+
x = DOFArray(actx, tuple([
249+
actx.from_numpy(randn(shape, np.float64))
250+
for shape in shapes]))
251+
252+
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
253+
dataclass_of_dofs = MyContainer(
254+
name="container",
255+
mass=x,
256+
momentum=make_obj_array([x] * ambient_dim),
257+
enthalpy=x)
258+
259+
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
260+
bcast_dataclass_of_dofs = MyContainerDOFBcast(
261+
name="container",
262+
mass=x,
263+
momentum=make_obj_array([x] * ambient_dim),
264+
enthalpy=x)
265+
266+
ary_dof = x
267+
ary_of_dofs = make_obj_array([x] * ambient_dim)
268+
mat_of_dofs = np.empty((ambient_dim, ambient_dim), dtype=object)
269+
for i in np.ndindex(mat_of_dofs.shape):
270+
mat_of_dofs[i] = x
271+
272+
return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs,
273+
bcast_dataclass_of_dofs)
274+
275+
# }}}
276+
277+
200278
# {{{ assert_close_to_numpy*
201279

202280
def randn(shape, dtype):
@@ -341,6 +419,23 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype):
341419
assert_close_to_numpy(
342420
actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args)
343421

422+
for c in (42.0,) + _get_test_containers(actx):
423+
result = getattr(actx.np, sym_name)(c)
424+
result = actx.thaw(actx.freeze(result))
425+
426+
if sym_name == "zeros_like":
427+
if np.isscalar(result):
428+
assert result == 0.0
429+
else:
430+
assert actx.to_numpy(actx.np.all(actx.np.equal(result, 0.0)))
431+
elif sym_name == "ones_like":
432+
if np.isscalar(result):
433+
assert result == 1.0
434+
else:
435+
assert actx.to_numpy(actx.np.all(actx.np.equal(result, 1.0)))
436+
else:
437+
raise ValueError(f"unknown method: '{sym_name}'")
438+
344439
# }}}
345440

346441

@@ -671,79 +766,6 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec):
671766

672767
# {{{ array container classes for test
673768

674-
@with_container_arithmetic(bcast_obj_array=False,
675-
eq_comparison=False, rel_comparison=False,
676-
_cls_has_array_context_attr=True)
677-
@dataclass_array_container
678-
@dataclass(frozen=True)
679-
class MyContainer:
680-
name: str
681-
mass: Union[DOFArray, np.ndarray]
682-
momentum: np.ndarray
683-
enthalpy: Union[DOFArray, np.ndarray]
684-
685-
@property
686-
def array_context(self):
687-
if isinstance(self.mass, np.ndarray):
688-
return next(iter(self.mass)).array_context
689-
else:
690-
return self.mass.array_context
691-
692-
693-
@with_container_arithmetic(
694-
bcast_obj_array=False,
695-
bcast_container_types=(DOFArray, np.ndarray),
696-
matmul=True,
697-
rel_comparison=True,
698-
_cls_has_array_context_attr=True)
699-
@dataclass_array_container
700-
@dataclass(frozen=True)
701-
class MyContainerDOFBcast:
702-
name: str
703-
mass: Union[DOFArray, np.ndarray]
704-
momentum: np.ndarray
705-
enthalpy: Union[DOFArray, np.ndarray]
706-
707-
@property
708-
def array_context(self):
709-
if isinstance(self.mass, np.ndarray):
710-
return next(iter(self.mass)).array_context
711-
else:
712-
return self.mass.array_context
713-
714-
715-
def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
716-
from numbers import Number
717-
if isinstance(shapes, (Number, tuple)):
718-
shapes = [shapes]
719-
720-
x = DOFArray(actx, tuple([
721-
actx.from_numpy(randn(shape, np.float64))
722-
for shape in shapes]))
723-
724-
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
725-
dataclass_of_dofs = MyContainer(
726-
name="container",
727-
mass=x,
728-
momentum=make_obj_array([x] * ambient_dim),
729-
enthalpy=x)
730-
731-
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
732-
bcast_dataclass_of_dofs = MyContainerDOFBcast(
733-
name="container",
734-
mass=x,
735-
momentum=make_obj_array([x] * ambient_dim),
736-
enthalpy=x)
737-
738-
ary_dof = x
739-
ary_of_dofs = make_obj_array([x] * ambient_dim)
740-
mat_of_dofs = np.empty((ambient_dim, ambient_dim), dtype=object)
741-
for i in np.ndindex(mat_of_dofs.shape):
742-
mat_of_dofs[i] = x
743-
744-
return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs,
745-
bcast_dataclass_of_dofs)
746-
747769

748770
def test_container_scalar_map(actx_factory):
749771
actx = actx_factory()

0 commit comments

Comments
 (0)