Skip to content

Commit 5c9d57a

Browse files
alexfiklinducer
authored andcommitted
forward actx.empty_like in actx.np.empty_like
1 parent cda94d4 commit 5c9d57a

File tree

4 files changed

+40
-34
lines changed

4 files changed

+40
-34
lines changed

arraycontext/fake_numpy.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,25 +91,11 @@ def _get_fake_numpy_linalg_namespace(self):
9191
# "interp",
9292
})
9393

94-
def _new_like(self, ary, alloc_like):
95-
if np.isscalar(ary):
96-
# NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
97-
# is best implemented by concrete array contexts, if at all
98-
raise NotImplementedError("operation not implemented for scalars")
99-
100-
if isinstance(ary, np.ndarray) and ary.dtype.char == "O":
101-
# NOTE: we don't want to match numpy semantics on object arrays,
102-
# e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
103-
# FIXME: what about object arrays nested in an ArrayContainer?
104-
raise NotImplementedError("operation not implemented for object arrays")
105-
106-
return rec_map_array_container(alloc_like, ary)
107-
10894
def empty_like(self, ary):
109-
return self._new_like(ary, self._array_context.empty_like)
95+
return self._array_context.empty_like(ary)
11096

11197
def zeros_like(self, ary):
112-
return self._new_like(ary, self._array_context.zeros_like)
98+
return self._array_context.zeros_like(ary)
11399

114100
def conjugate(self, x):
115101
# NOTE: conjugate distributes over object arrays, but it looks for a

arraycontext/impl/jax/fake_numpy.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
"""
2424
from functools import partial, reduce
2525

26+
import numpy as np
27+
import jax.numpy as jnp
28+
2629
from arraycontext.fake_numpy import (
2730
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
2831
)
@@ -31,8 +34,6 @@
3134
rec_map_reduce_array_container,
3235
)
3336
from arraycontext.container import NotAnArrayContainerError, serialize_container
34-
import numpy
35-
import jax.numpy as jnp
3637

3738

3839
class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
@@ -62,7 +63,8 @@ def full_like(self, ary, fill_value):
6263
def _full_like(subary):
6364
return jnp.full_like(subary, fill_value)
6465

65-
return self._new_like(ary, _full_like)
66+
return self._array_context._rec_map_container(
67+
_full_like, ary, default_scalar=fill_value)
6668

6769
# }}}
6870

@@ -111,11 +113,10 @@ def vdot(self, x, y, dtype=None):
111113
from arraycontext import rec_multimap_reduce_array_container
112114

113115
def _rec_vdot(ary1, ary2):
114-
if dtype not in [None, numpy.find_common_type((ary1.dtype,
115-
ary2.dtype),
116-
())]:
117-
raise NotImplementedError(f"{type(self)} cannot take dtype in"
118-
" vdot.")
116+
common_dtype = np.find_common_type((ary1.dtype, ary2.dtype), ())
117+
if dtype not in [None, common_dtype]:
118+
raise NotImplementedError(
119+
f"{type(self).__name__} cannot take dtype in vdot.")
119120

120121
return jnp.vdot(ary1, ary2)
121122

@@ -129,8 +130,8 @@ def array_equal(self, a, b):
129130
actx = self._array_context
130131

131132
# NOTE: not all backends support `bool` properly, so use `int8` instead
132-
true = actx.from_numpy(numpy.int8(True))
133-
false = actx.from_numpy(numpy.int8(False))
133+
true = actx.from_numpy(np.int8(True))
134+
false = actx.from_numpy(np.int8(False))
134135

135136
def rec_equal(x, y):
136137
if type(x) != type(y):

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,24 @@ def ones_like(self, ary):
6767
return self.full_like(ary, 1)
6868

6969
def full_like(self, ary, fill_value):
70+
import arraycontext.impl.pyopencl.taggable_cl_array as tga
71+
7072
def _full_like(subary):
71-
ones = self._array_context.empty_like(subary)
72-
ones.fill(fill_value)
73-
return ones
73+
filled = tga.empty(
74+
self._array_context.queue, subary.shape, subary.dtype,
75+
allocator=self._array_context.allocator,
76+
axes=subary.axes, tags=subary.tags)
77+
filled.fill(fill_value)
78+
return filled
7479

75-
return self._new_like(ary, _full_like)
80+
return self._array_context._rec_map_container(
81+
_full_like, ary, default_scalar=fill_value)
7682

7783
def copy(self, ary):
7884
def _copy(subary):
7985
return subary.copy(queue=self._array_context.queue)
8086

81-
return self._new_like(ary, _copy)
87+
return self._array_context._rec_map_container(_copy, ary)
8288

8389
# }}}
8490

@@ -144,9 +150,15 @@ def vdot(self, x, y, dtype=None):
144150

145151
def all(self, a):
146152
queue = self._array_context.queue
153+
154+
def _all(ary):
155+
if np.isscalar(ary):
156+
return np.int8(all([ary]))
157+
return ary.all(queue=queue)
158+
147159
result = rec_map_reduce_array_container(
148160
partial(reduce, partial(cl_array.minimum, queue=queue)),
149-
lambda subary: subary.all(queue=queue),
161+
_all,
150162
a)
151163

152164
if not self._array_context._force_device_scalars:
@@ -155,9 +167,15 @@ def all(self, a):
155167

156168
def any(self, a):
157169
queue = self._array_context.queue
170+
171+
def _any(ary):
172+
if np.isscalar(ary):
173+
return np.int8(any([ary]))
174+
return ary.any(queue=queue)
175+
158176
result = rec_map_reduce_array_container(
159177
partial(reduce, partial(cl_array.maximum, queue=queue)),
160-
lambda subary: subary.any(queue=queue),
178+
_any,
161179
a)
162180

163181
if not self._array_context._force_device_scalars:

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def full_like(self, ary, fill_value):
8383
def _full_like(subary):
8484
return pt.full(subary.shape, fill_value, subary.dtype)
8585

86-
return self._new_like(ary, _full_like)
86+
return self._array_context._rec_map_container(
87+
_full_like, ary, default_scalar=fill_value)
8788

8889
# }}}
8990

0 commit comments

Comments
 (0)