2727THE SOFTWARE.
2828"""
2929
30- from typing import Optional , Tuple , Union , Sequence , Dict , List
30+ from typing import Any , Optional , Tuple , Union , Sequence , Dict , List
31+ from abc import ABC , abstractmethod
32+
33+ import numpy as np
34+
3135from pytato .array import ShapeType , Array , make_index_lambda
3236from pytato .scalar_expr import ScalarExpression , Reduce , INT_CLASSES
3337import pymbolic .primitives as prim
4347.. autofunction:: prod
4448.. autofunction:: all
4549.. autofunction:: any
50+
51+ .. currentmodule:: pytato.reductions
52+
53+ .. autoclass:: ReductionOperation
54+ .. autoclass:: SumReductionOperation
55+ .. autoclass:: ProductReductionOperation
56+ .. autoclass:: MaxReductionOperation
57+ .. autoclass:: MinReductionOperation
58+ .. autoclass:: AllReductionOperation
59+ .. autoclass:: AnyReductionOperation
4660"""
4761
4862# }}}
4963
5064
65+ class _NoValue :
66+ pass
67+
68+
69+ # {{{ reduction operations
70+
71+ class ReductionOperation (ABC ):
72+ """
73+ .. automethod:: neutral_element
74+ .. automethod:: __hash__
75+ .. automethod:: __eq__
76+ """
77+
78+ @abstractmethod
79+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
80+ pass
81+
82+ @abstractmethod
83+ def __hash__ (self ) -> int :
84+ pass
85+
86+ @abstractmethod
87+ def __eq__ (self , other : Any ) -> bool :
88+ pass
89+
90+
91+ class _StatelessReductionOperation (ReductionOperation ):
92+ def __hash__ (self ) -> int :
93+ return hash (type (self ))
94+
95+ def __eq__ (self , other : Any ) -> bool :
96+ return type (self ) is type (other )
97+
98+
99+ class SumReductionOperation (_StatelessReductionOperation ):
100+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
101+ return 0
102+
103+
104+ class ProductReductionOperation (_StatelessReductionOperation ):
105+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
106+ return 1
107+
108+
109+ class MaxReductionOperation (_StatelessReductionOperation ):
110+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
111+ if dtype .kind == "f" :
112+ return dtype .type (float ("-inf" ))
113+ elif dtype .kind == "i" :
114+ return np .iinfo (dtype ).min
115+ else :
116+ raise TypeError (f"unknown neutral element for max and { dtype } " )
117+
118+
119+ class MinReductionOperation (_StatelessReductionOperation ):
120+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
121+ if dtype .kind == "f" :
122+ return dtype .type (float ("inf" ))
123+ elif dtype .kind == "i" :
124+ return np .iinfo (dtype ).max
125+ else :
126+ raise TypeError (f"unknown neutral element for min and { dtype } " )
127+
128+
129+ class AllReductionOperation (_StatelessReductionOperation ):
130+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
131+ return np .bool_ (True )
132+
133+
134+ class AnyReductionOperation (_StatelessReductionOperation ):
135+ def neutral_element (self , dtype : np .dtype [Any ]) -> Any :
136+ return np .bool_ (False )
137+
138+ # }}}
139+
140+
51141# {{{ reductions
52142
53143def _normalize_reduction_axes (
@@ -124,8 +214,9 @@ def _get_reduction_indices_bounds(shape: ShapeType,
124214 return indices , pmap (redn_bounds ) # type: ignore
125215
126216
127- def _make_reduction_lambda (op : str , a : Array ,
128- axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
217+ def _make_reduction_lambda (op : ReductionOperation , a : Array ,
218+ axis : Optional [Union [int , Tuple [int ]]],
219+ initial : Any ) -> Array :
129220 """
130221 Return a :class:`IndexLambda` that performs reduction over the *axis* axes
131222 of *a* with the reduction op *op*.
@@ -137,9 +228,28 @@ def _make_reduction_lambda(op: str, a: Array,
137228 :arg axis: The axes over which the reduction is to be performed. If axis is
138229 *None*, perform reduction over all of *a*'s axes.
139230 """
140- new_shape , axes = _normalize_reduction_axes (a .shape , axis )
231+ new_shape , reduction_axes = _normalize_reduction_axes (a .shape , axis )
141232 del axis
142- indices , redn_bounds = _get_reduction_indices_bounds (a .shape , axes )
233+ indices , redn_bounds = _get_reduction_indices_bounds (a .shape , reduction_axes )
234+
235+ if initial is _NoValue :
236+ for iax in reduction_axes :
237+ shape_iax = a .shape [iax ]
238+
239+ from pytato .utils import are_shape_components_equal
240+ if are_shape_components_equal (shape_iax , 0 ):
241+ raise ValueError (
242+ "zero-size reduction operation with no supplied "
243+ "'initial' value" )
244+
245+ if isinstance (iax , Array ):
246+ raise NotImplementedError (
247+ "cannot statically determine emptiness of "
248+ f"reduction axis { iax } (0-based)" )
249+
250+ elif initial != op .neutral_element (a .dtype ):
251+ raise NotImplementedError ("reduction with 'initial' not equal to the "
252+ "neutral element" )
143253
144254 return make_index_lambda (
145255 Reduce (
@@ -151,52 +261,74 @@ def _make_reduction_lambda(op: str, a: Array,
151261 a .dtype )
152262
153263
154- def sum (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
264+ def sum (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ,
265+ initial : Any = 0 ) -> Array :
155266 """
156267 Sums array *a*'s elements along the *axis* axes.
157268
158269 :arg a: The :class:`pytato.Array` on which to perform the reduction.
159270
160271 :arg axis: The axes along which the elements are to be sum-reduced.
161272 Defaults to all axes of the input array.
273+ :arg initial: The value returned for an empty array, if supplied.
274+ This value also serves as the base value onto which any additional
275+ array entries are accumulated.
162276 """
163- return _make_reduction_lambda ("sum" , a , axis )
277+ return _make_reduction_lambda (SumReductionOperation () , a , axis , initial )
164278
165279
166- def amax (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
280+ def amax (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None , * ,
281+ initial : Any = _NoValue ) -> Array :
167282 """
168283 Returns the max of array *a*'s elements along the *axis* axes.
169284
170285 :arg a: The :class:`pytato.Array` on which to perform the reduction.
171286
172287 :arg axis: The axes along which the elements are to be max-reduced.
173288 Defaults to all axes of the input array.
289+ :arg initial: The value returned for an empty array, if supplied.
290+ This value also serves as the base value onto which any additional
291+ array entries are accumulated.
292+ If not supplied, an :exc:`ValueError` will be raised
293+ if the reduction is empty.
294+ In that case, the reduction size must not be symbolic.
174295 """
175- return _make_reduction_lambda ("max" , a , axis )
296+ return _make_reduction_lambda (MaxReductionOperation () , a , axis , initial )
176297
177298
178- def amin (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
299+ def amin (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ,
300+ initial : Any = _NoValue ) -> Array :
179301 """
180302 Returns the min of array *a*'s elements along the *axis* axes.
181303
182304 :arg a: The :class:`pytato.Array` on which to perform the reduction.
183305
184306 :arg axis: The axes along which the elements are to be min-reduced.
185307 Defaults to all axes of the input array.
308+ :arg initial: The value returned for an empty array, if supplied.
309+ This value also serves as the base value onto which any additional
310+ array entries are accumulated.
311+ If not supplied, an :exc:`ValueError` will be raised
312+ if the reduction is empty.
313+ In that case, the reduction size must not be symbolic.
186314 """
187- return _make_reduction_lambda ("min" , a , axis )
315+ return _make_reduction_lambda (MinReductionOperation () , a , axis , initial )
188316
189317
190- def prod (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
318+ def prod (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ,
319+ initial : Any = 1 ) -> Array :
191320 """
192321 Returns the product of array *a*'s elements along the *axis* axes.
193322
194323 :arg a: The :class:`pytato.Array` on which to perform the reduction.
195324
196325 :arg axis: The axes along which the elements are to be product-reduced.
197326 Defaults to all axes of the input array.
327+ :arg initial: The value returned for an empty array, if supplied.
328+ This value also serves as the base value onto which any additional
329+ array entries are accumulated.
198330 """
199- return _make_reduction_lambda ("product" , a , axis )
331+ return _make_reduction_lambda (ProductReductionOperation () , a , axis , initial )
200332
201333
202334def all (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
@@ -208,7 +340,7 @@ def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
208340 :arg axis: The axes along which the elements are to be product-reduced.
209341 Defaults to all axes of the input array.
210342 """
211- return _make_reduction_lambda ("all" , a , axis )
343+ return _make_reduction_lambda (AllReductionOperation () , a , axis , initial = True )
212344
213345
214346def any (a : Array , axis : Optional [Union [int , Tuple [int ]]] = None ) -> Array :
@@ -220,7 +352,7 @@ def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
220352 :arg axis: The axes along which the elements are to be product-reduced.
221353 Defaults to all axes of the input array.
222354 """
223- return _make_reduction_lambda ("any" , a , axis )
355+ return _make_reduction_lambda (AnyReductionOperation () , a , axis , initial = False )
224356
225357# }}}
226358
0 commit comments