@@ -50,40 +50,81 @@ def _get_fake_numpy_linalg_namespace(self):
5050 def __getattr__ (self , name ):
5151 return partial (rec_multimap_array_container , getattr (jnp , name ))
5252
53+ # NOTE: the order of these follows the order in numpy docs
54+ # NOTE: when adding a function here, also add it to `array_context.rst` docs!
55+
56+ # {{{ array creation routines
57+
58+ def ones_like (self , ary ):
59+ return self .full_like (ary , 1 )
60+
61+ def full_like (self , ary , fill_value ):
62+ def _full_like (subary ):
63+ return jnp .full_like (ary , fill_value )
64+
65+ return self ._new_like (ary , _full_like )
66+
67+ # }}}
68+
69+ # {{{ array manipulation routies
70+
5371 def reshape (self , a , newshape , order = "C" ):
5472 return rec_map_array_container (
5573 lambda ary : jnp .reshape (ary , newshape , order = order ),
5674 a )
5775
58- def transpose (self , a , axes = None ):
59- return rec_multimap_array_container (jnp .transpose , a , axes )
76+ def ravel (self , a , order = "C" ):
77+ """
78+ .. warning::
6079
61- def concatenate (self , arrays , axis = 0 ):
62- return rec_multimap_array_container (jnp .concatenate , arrays , axis )
80+ Since :func:`jax.numpy.reshape` does not support orders `A`` and
81+ ``K``, in such cases we fallback to using ``order = C``.
82+ """
83+ if order in "AK" :
84+ from warnings import warn
85+ warn (f"ravel with order='{ order } ' not supported by JAX,"
86+ " using order=C." )
87+ order = "C"
6388
64- def where ( self , criterion , then , else_ ):
65- return rec_multimap_array_container ( jnp .where , criterion , then , else_ )
89+ return rec_map_array_container (
90+ lambda subary : jnp .ravel ( subary , order = order ), a )
6691
67- def sum (self , a , axis = None , dtype = None ):
68- return rec_map_reduce_array_container (sum ,
69- partial (jnp .sum ,
70- axis = axis ,
71- dtype = dtype ),
72- a )
92+ def transpose (self , a , axes = None ):
93+ return rec_multimap_array_container (jnp .transpose , a , axes )
7394
74- def min (self , a , axis = None ):
75- return rec_map_reduce_array_container (
76- partial (reduce , jnp .minimum ), partial (jnp .amin , axis = axis ), a )
95+ def broadcast_to (self , array , shape ):
96+ return rec_map_array_container (partial (jnp .broadcast_to , shape = shape ), array )
7797
78- def max (self , a , axis = None ):
79- return rec_map_reduce_array_container (
80- partial (reduce , jnp .maximum ), partial (jnp .amax , axis = axis ), a )
98+ def concatenate (self , arrays , axis = 0 ):
99+ return rec_multimap_array_container (jnp .concatenate , arrays , axis )
81100
82101 def stack (self , arrays , axis = 0 ):
83102 return rec_multimap_array_container (
84103 lambda * args : jnp .stack (arrays = args , axis = axis ),
85104 * arrays )
86105
106+ # }}}
107+
108+ # {{{ linear algebra
109+
110+ def vdot (self , x , y , dtype = None ):
111+ from arraycontext import rec_multimap_reduce_array_container
112+
113+ def _rec_vdot (ary1 , ary2 ):
114+ if dtype not in [None , numpy .find_common_type ((ary1 .dtype ,
115+ ary2 .dtype ),
116+ ())]:
117+ raise NotImplementedError (f"{ type (self )} cannot take dtype in"
118+ " vdot." )
119+
120+ return jnp .vdot (ary1 , ary2 )
121+
122+ return rec_multimap_reduce_array_container (sum , _rec_vdot , x , y )
123+
124+ # }}}
125+
126+ # {{{ logic functions
127+
87128 def array_equal (self , a , b ):
88129 actx = self ._array_context
89130
@@ -109,35 +150,33 @@ def rec_equal(x, y):
109150
110151 return rec_equal (a , b )
111152
112- def ravel (self , a , order = "C" ):
113- """
114- .. warning::
153+ # }}}
115154
116- Since :func:`jax.numpy.reshape` does not support orders `A`` and
117- ``K``, in such cases we fallback to using ``order = C``.
118- """
119- if order in "AK" :
120- from warnings import warn
121- warn (f"ravel with order='{ order } ' not supported by JAX,"
122- " using order=C." )
123- order = "C"
155+ # {{{ mathematical functions
156+
157+ def sum (self , a , axis = None , dtype = None ):
158+ return rec_map_reduce_array_container (
159+ sum ,
160+ partial (jnp .sum , axis = axis , dtype = dtype ),
161+ a )
124162
125- return rec_map_array_container (lambda subary : jnp .ravel (subary , order = order ),
126- a )
163+ def amin (self , a , axis = None ):
164+ return rec_map_reduce_array_container (
165+ partial (reduce , jnp .minimum ), partial (jnp .amin , axis = axis ), a )
127166
128- def vdot (self , x , y , dtype = None ):
129- from arraycontext import rec_multimap_reduce_array_container
167+ min = amin
130168
131- def _rec_vdot (ary1 , ary2 ):
132- if dtype not in [None , numpy .find_common_type ((ary1 .dtype ,
133- ary2 .dtype ),
134- ())]:
135- raise NotImplementedError (f"{ type (self )} cannot take dtype in"
136- " vdot." )
169+ def amax (self , a , axis = None ):
170+ return rec_map_reduce_array_container (
171+ partial (reduce , jnp .maximum ), partial (jnp .amax , axis = axis ), a )
137172
138- return jnp . vdot ( ary1 , ary2 )
173+ max = amax
139174
140- return rec_multimap_reduce_array_container ( sum , _rec_vdot , x , y )
175+ # }}}
141176
142- def broadcast_to (self , array , shape ):
143- return rec_map_array_container (partial (jnp .broadcast_to , shape = shape ), array )
177+ # {{{ sorting, searching and counting
178+
179+ def where (self , criterion , then , else_ ):
180+ return rec_multimap_array_container (jnp .where , criterion , then , else_ )
181+
182+ # }}}
0 commit comments