Skip to content

Commit ff1cd0c

Browse files
alexfiklinducer
authored andcommitted
rearrange jax.fake_numpy to match other contexts
1 parent be1429c commit ff1cd0c

File tree

1 file changed

+82
-43
lines changed

1 file changed

+82
-43
lines changed

arraycontext/impl/jax/fake_numpy.py

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -50,40 +50,81 @@ def _get_fake_numpy_linalg_namespace(self):
5050
def __getattr__(self, name):
5151
return partial(rec_multimap_array_container, getattr(jnp, name))
5252

53+
# NOTE: the order of these follows the order in numpy docs
54+
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
55+
56+
# {{{ array creation routines
57+
58+
def ones_like(self, ary):
59+
return self.full_like(ary, 1)
60+
61+
def full_like(self, ary, fill_value):
62+
def _full_like(subary):
63+
return jnp.full_like(ary, fill_value)
64+
65+
return self._new_like(ary, _full_like)
66+
67+
# }}}
68+
69+
# {{{ array manipulation routies
70+
5371
def reshape(self, a, newshape, order="C"):
5472
return rec_map_array_container(
5573
lambda ary: jnp.reshape(ary, newshape, order=order),
5674
a)
5775

58-
def transpose(self, a, axes=None):
59-
return rec_multimap_array_container(jnp.transpose, a, axes)
76+
def ravel(self, a, order="C"):
77+
"""
78+
.. warning::
6079
61-
def concatenate(self, arrays, axis=0):
62-
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
80+
Since :func:`jax.numpy.reshape` does not support orders `A`` and
81+
``K``, in such cases we fallback to using ``order = C``.
82+
"""
83+
if order in "AK":
84+
from warnings import warn
85+
warn(f"ravel with order='{order}' not supported by JAX,"
86+
" using order=C.")
87+
order = "C"
6388

64-
def where(self, criterion, then, else_):
65-
return rec_multimap_array_container(jnp.where, criterion, then, else_)
89+
return rec_map_array_container(
90+
lambda subary: jnp.ravel(subary, order=order), a)
6691

67-
def sum(self, a, axis=None, dtype=None):
68-
return rec_map_reduce_array_container(sum,
69-
partial(jnp.sum,
70-
axis=axis,
71-
dtype=dtype),
72-
a)
92+
def transpose(self, a, axes=None):
93+
return rec_multimap_array_container(jnp.transpose, a, axes)
7394

74-
def min(self, a, axis=None):
75-
return rec_map_reduce_array_container(
76-
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
95+
def broadcast_to(self, array, shape):
96+
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
7797

78-
def max(self, a, axis=None):
79-
return rec_map_reduce_array_container(
80-
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
98+
def concatenate(self, arrays, axis=0):
99+
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
81100

82101
def stack(self, arrays, axis=0):
83102
return rec_multimap_array_container(
84103
lambda *args: jnp.stack(arrays=args, axis=axis),
85104
*arrays)
86105

106+
# }}}
107+
108+
# {{{ linear algebra
109+
110+
def vdot(self, x, y, dtype=None):
111+
from arraycontext import rec_multimap_reduce_array_container
112+
113+
def _rec_vdot(ary1, ary2):
114+
if dtype not in [None, numpy.find_common_type((ary1.dtype,
115+
ary2.dtype),
116+
())]:
117+
raise NotImplementedError(f"{type(self)} cannot take dtype in"
118+
" vdot.")
119+
120+
return jnp.vdot(ary1, ary2)
121+
122+
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
123+
124+
# }}}
125+
126+
# {{{ logic functions
127+
87128
def array_equal(self, a, b):
88129
actx = self._array_context
89130

@@ -109,35 +150,33 @@ def rec_equal(x, y):
109150

110151
return rec_equal(a, b)
111152

112-
def ravel(self, a, order="C"):
113-
"""
114-
.. warning::
153+
# }}}
115154

116-
Since :func:`jax.numpy.reshape` does not support orders `A`` and
117-
``K``, in such cases we fallback to using ``order = C``.
118-
"""
119-
if order in "AK":
120-
from warnings import warn
121-
warn(f"ravel with order='{order}' not supported by JAX,"
122-
" using order=C.")
123-
order = "C"
155+
# {{{ mathematical functions
156+
157+
def sum(self, a, axis=None, dtype=None):
158+
return rec_map_reduce_array_container(
159+
sum,
160+
partial(jnp.sum, axis=axis, dtype=dtype),
161+
a)
124162

125-
return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order),
126-
a)
163+
def amin(self, a, axis=None):
164+
return rec_map_reduce_array_container(
165+
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
127166

128-
def vdot(self, x, y, dtype=None):
129-
from arraycontext import rec_multimap_reduce_array_container
167+
min = amin
130168

131-
def _rec_vdot(ary1, ary2):
132-
if dtype not in [None, numpy.find_common_type((ary1.dtype,
133-
ary2.dtype),
134-
())]:
135-
raise NotImplementedError(f"{type(self)} cannot take dtype in"
136-
" vdot.")
169+
def amax(self, a, axis=None):
170+
return rec_map_reduce_array_container(
171+
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
137172

138-
return jnp.vdot(ary1, ary2)
173+
max = amax
139174

140-
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
175+
# }}}
141176

142-
def broadcast_to(self, array, shape):
143-
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
177+
# {{{ sorting, searching and counting
178+
179+
def where(self, criterion, then, else_):
180+
return rec_multimap_array_container(jnp.where, criterion, then, else_)
181+
182+
# }}}

0 commit comments

Comments
 (0)