Skip to content

Commit 8cf2575

Browse files
committed
Introduce ReductionOperation class, accept 'initial' in reductions
1 parent 45bd2a9 commit 8cf2575

File tree

4 files changed

+174
-36
lines changed

4 files changed

+174
-36
lines changed

pytato/codegen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,9 @@ def map_einsum(self, expr: Einsum) -> Array:
334334
args_as_pym_expr[0])
335335

336336
if redn_bounds:
337+
from pytato.reductions import SumReductionOperation
337338
inner_expr = Reduce(inner_expr,
338-
"sum",
339+
SumReductionOperation(),
339340
redn_bounds)
340341

341342
return IndexLambda(expr=inner_expr,

pytato/reductions.py

Lines changed: 147 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
THE SOFTWARE.
2828
"""
2929

30-
from typing import Optional, Tuple, Union, Sequence, Dict, List
30+
from typing import Any, Optional, Tuple, Union, Sequence, Dict, List
31+
from abc import ABC, abstractmethod
32+
33+
import numpy as np
34+
3135
from pytato.array import ShapeType, Array, make_index_lambda
3236
from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES
3337
import pymbolic.primitives as prim
@@ -43,11 +47,97 @@
4347
.. autofunction:: prod
4448
.. autofunction:: all
4549
.. autofunction:: any
50+
51+
.. currentmodule:: pytato.reductions
52+
53+
.. autoclass:: ReductionOperation
54+
.. autoclass:: SumReductionOperation
55+
.. autoclass:: ProductReductionOperation
56+
.. autoclass:: MaxReductionOperation
57+
.. autoclass:: MinReductionOperation
58+
.. autoclass:: AllReductionOperation
59+
.. autoclass:: AnyReductionOperation
4660
"""
4761

4862
# }}}
4963

5064

65+
class _NoValue:
66+
pass
67+
68+
69+
# {{{ reduction operations
70+
71+
class ReductionOperation(ABC):
72+
"""
73+
.. automethod:: neutral_element
74+
.. automethod:: __hash__
75+
.. automethod:: __eq__
76+
"""
77+
78+
@abstractmethod
79+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
80+
pass
81+
82+
@abstractmethod
83+
def __hash__(self) -> int:
84+
pass
85+
86+
@abstractmethod
87+
def __eq__(self, other: Any) -> bool:
88+
pass
89+
90+
91+
class _StatelessReductionOperation(ReductionOperation):
92+
def __hash__(self) -> int:
93+
return hash(type(self))
94+
95+
def __eq__(self, other: Any) -> bool:
96+
return type(self) is type(other)
97+
98+
99+
class SumReductionOperation(_StatelessReductionOperation):
100+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
101+
return 0
102+
103+
104+
class ProductReductionOperation(_StatelessReductionOperation):
105+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
106+
return 1
107+
108+
109+
class MaxReductionOperation(_StatelessReductionOperation):
110+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
111+
if dtype.kind == "f":
112+
return dtype.type(float("-inf"))
113+
elif dtype.kind == "i":
114+
return np.iinfo(dtype).min
115+
else:
116+
raise TypeError(f"unknown neutral element for max and {dtype}")
117+
118+
119+
class MinReductionOperation(_StatelessReductionOperation):
120+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
121+
if dtype.kind == "f":
122+
return dtype.type(float("inf"))
123+
elif dtype.kind == "i":
124+
return np.iinfo(dtype).max
125+
else:
126+
raise TypeError(f"unknown neutral element for min and {dtype}")
127+
128+
129+
class AllReductionOperation(_StatelessReductionOperation):
130+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
131+
return np.bool_(True)
132+
133+
134+
class AnyReductionOperation(_StatelessReductionOperation):
135+
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
136+
return np.bool_(False)
137+
138+
# }}}
139+
140+
51141
# {{{ reductions
52142

53143
def _normalize_reduction_axes(
@@ -124,8 +214,9 @@ def _get_reduction_indices_bounds(shape: ShapeType,
124214
return indices, pmap(redn_bounds) # type: ignore
125215

126216

127-
def _make_reduction_lambda(op: str, a: Array,
128-
axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
217+
def _make_reduction_lambda(op: ReductionOperation, a: Array,
218+
axis: Optional[Union[int, Tuple[int]]],
219+
initial: Any) -> Array:
129220
"""
130221
Return a :class:`IndexLambda` that performs reduction over the *axis* axes
131222
of *a* with the reduction op *op*.
@@ -137,9 +228,28 @@ def _make_reduction_lambda(op: str, a: Array,
137228
:arg axis: The axes over which the reduction is to be performed. If axis is
138229
*None*, perform reduction over all of *a*'s axes.
139230
"""
140-
new_shape, axes = _normalize_reduction_axes(a.shape, axis)
231+
new_shape, reduction_axes = _normalize_reduction_axes(a.shape, axis)
141232
del axis
142-
indices, redn_bounds = _get_reduction_indices_bounds(a.shape, axes)
233+
indices, redn_bounds = _get_reduction_indices_bounds(a.shape, reduction_axes)
234+
235+
if initial is _NoValue:
236+
for iax in reduction_axes:
237+
shape_iax = a.shape[iax]
238+
239+
from pytato.utils import are_shape_components_equal
240+
if are_shape_components_equal(shape_iax, 0):
241+
raise ValueError(
242+
"zero-size reduction operation with no supplied "
243+
"'initial' value")
244+
245+
if isinstance(iax, Array):
246+
raise NotImplementedError(
247+
"cannot statically determine emptiness of "
248+
f"reduction axis {iax} (0-based)")
249+
250+
elif initial != op.neutral_element(a.dtype):
251+
raise NotImplementedError("reduction with 'initial' not equal to the "
252+
"neutral element")
143253

144254
return make_index_lambda(
145255
Reduce(
@@ -151,52 +261,74 @@ def _make_reduction_lambda(op: str, a: Array,
151261
a.dtype)
152262

153263

154-
def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
264+
def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
265+
initial: Any = 0) -> Array:
155266
"""
156267
Sums array *a*'s elements along the *axis* axes.
157268
158269
:arg a: The :class:`pytato.Array` on which to perform the reduction.
159270
160271
:arg axis: The axes along which the elements are to be sum-reduced.
161272
Defaults to all axes of the input array.
273+
:arg initial: The value returned for an empty array, if supplied.
274+
This value also serves as the base value onto which any additional
275+
array entries are accumulated.
162276
"""
163-
return _make_reduction_lambda("sum", a, axis)
277+
return _make_reduction_lambda(SumReductionOperation(), a, axis, initial)
164278

165279

166-
def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
280+
def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, *,
281+
initial: Any = _NoValue) -> Array:
167282
"""
168283
Returns the max of array *a*'s elements along the *axis* axes.
169284
170285
:arg a: The :class:`pytato.Array` on which to perform the reduction.
171286
172287
:arg axis: The axes along which the elements are to be max-reduced.
173288
Defaults to all axes of the input array.
289+
:arg initial: The value returned for an empty array, if supplied.
290+
This value also serves as the base value onto which any additional
291+
array entries are accumulated.
292+
If not supplied, an :exc:`ValueError` will be raised
293+
if the reduction is empty.
294+
In that case, the reduction size must not be symbolic.
174295
"""
175-
return _make_reduction_lambda("max", a, axis)
296+
return _make_reduction_lambda(MaxReductionOperation(), a, axis, initial)
176297

177298

178-
def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
299+
def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
300+
initial: Any = _NoValue) -> Array:
179301
"""
180302
Returns the min of array *a*'s elements along the *axis* axes.
181303
182304
:arg a: The :class:`pytato.Array` on which to perform the reduction.
183305
184306
:arg axis: The axes along which the elements are to be min-reduced.
185307
Defaults to all axes of the input array.
308+
:arg initial: The value returned for an empty array, if supplied.
309+
This value also serves as the base value onto which any additional
310+
array entries are accumulated.
311+
If not supplied, an :exc:`ValueError` will be raised
312+
if the reduction is empty.
313+
In that case, the reduction size must not be symbolic.
186314
"""
187-
return _make_reduction_lambda("min", a, axis)
315+
return _make_reduction_lambda(MinReductionOperation(), a, axis, initial)
188316

189317

190-
def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
318+
def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
319+
initial: Any = 1) -> Array:
191320
"""
192321
Returns the product of array *a*'s elements along the *axis* axes.
193322
194323
:arg a: The :class:`pytato.Array` on which to perform the reduction.
195324
196325
:arg axis: The axes along which the elements are to be product-reduced.
197326
Defaults to all axes of the input array.
327+
:arg initial: The value returned for an empty array, if supplied.
328+
This value also serves as the base value onto which any additional
329+
array entries are accumulated.
198330
"""
199-
return _make_reduction_lambda("product", a, axis)
331+
return _make_reduction_lambda(ProductReductionOperation(), a, axis, initial)
200332

201333

202334
def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
@@ -208,7 +340,7 @@ def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
208340
:arg axis: The axes along which the elements are to be product-reduced.
209341
Defaults to all axes of the input array.
210342
"""
211-
return _make_reduction_lambda("all", a, axis)
343+
return _make_reduction_lambda(AllReductionOperation(), a, axis, initial=True)
212344

213345

214346
def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
@@ -220,7 +352,7 @@ def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
220352
:arg axis: The axes along which the elements are to be product-reduced.
221353
Defaults to all axes of the input array.
222354
"""
223-
return _make_reduction_lambda("any", a, axis)
355+
return _make_reduction_lambda(AnyReductionOperation(), a, axis, initial=False)
224356

225357
# }}}
226358

pytato/scalar_expr.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
"""
2626

2727
from numbers import Number
28-
from typing import Any, Union, Mapping, FrozenSet, Set, Tuple, Optional
28+
from typing import (
29+
Any, Union, Mapping, FrozenSet, Set, Tuple, Optional, TYPE_CHECKING)
2930

3031
from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as
3132
IdentityMapperBase)
@@ -44,6 +45,10 @@
4445
import numpy as np
4546
import re
4647

48+
if TYPE_CHECKING:
49+
from pytato.reductions import ReductionOperation
50+
51+
4752
__doc__ = """
4853
.. currentmodule:: pytato.scalar_expr
4954
@@ -232,21 +237,20 @@ class Reduce(ExpressionBase):
232237
233238
.. attribute:: op
234239
235-
One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``,``"all"``, ``"any"``.
240+
A :class:`pytato.reductions.ReductionOperation`.
236241
237242
.. attribute:: bounds
238243
239244
A mapping from reduction inames to tuples ``(lower_bound, upper_bound)``
240245
identifying half-open bounds intervals. Must be hashable.
241246
"""
242247
inner_expr: ScalarExpression
243-
op: str
248+
op: ReductionOperation
244249
bounds: Mapping[str, Tuple[ScalarExpression, ScalarExpression]]
245250

246-
def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None:
251+
def __init__(self, inner_expr: ScalarExpression,
252+
op: ReductionOperation, bounds: Any) -> None:
247253
self.inner_expr = inner_expr
248-
if op not in {"sum", "product", "max", "min", "all", "any"}:
249-
raise ValueError(f"unsupported op: {op}")
250254
self.op = op
251255
self.bounds = bounds
252256

@@ -256,7 +260,7 @@ def __hash__(self) -> int:
256260
tuple(self.bounds.keys()),
257261
tuple(self.bounds.values())))
258262

259-
def __getinitargs__(self) -> Tuple[ScalarExpression, str, Any]:
263+
def __getinitargs__(self) -> Tuple[ScalarExpression, ReductionOperation, Any]:
260264
return (self.inner_expr, self.op, self.bounds)
261265

262266
mapper_method = "map_reduce"

pytato/target/loopy/codegen.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pytato.loopy import LoopyCall
4949
from pytato.tags import ImplStored, _BaseNameTag, Named, PrefixNamed
5050
from pytools.tag import Tag
51+
import pytato.reductions as red
5152

5253
# set in doc/conf.py
5354
if getattr(sys, "PYTATO_BUILDING_SPHINX_DOCS", False):
@@ -537,13 +538,13 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef":
537538
REDUCTION_INDEX_RE = re.compile("_r(0|([1-9][0-9]*))")
538539

539540
# Maps Pytato reduction types to the corresponding Loopy reduction types.
540-
PYTATO_REDUCTION_TO_LOOPY_REDUCTION = {
541-
"sum": "sum",
542-
"product": "product",
543-
"max": "max",
544-
"min": "min",
545-
"all": "all",
546-
"any": "any",
541+
PYTATO_REDUCTION_TO_LOOPY_REDUCTION: Mapping[Type[red.ReductionOperation], str] = {
542+
red.SumReductionOperation: "sum",
543+
red.ProductReductionOperation: "product",
544+
red.MaxReductionOperation: "max",
545+
red.MinReductionOperation: "min",
546+
red.AllReductionOperation: "all",
547+
red.AnyReductionOperation: "any",
547548
}
548549

549550

@@ -620,8 +621,13 @@ def map_reduce(self, expr: scalar_expr.Reduce,
620621
from loopy.symbolic import Reduction as LoopyReduction
621622
state = prstnt_ctx.state
622623

624+
try:
625+
loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(expr.op)]
626+
except KeyError:
627+
raise NotImplementedError(expr.op)
628+
623629
unique_names_mapping = {
624-
old_name: state.var_name_gen(f"_pt_{expr.op}" + old_name)
630+
old_name: state.var_name_gen(f"_pt_{loopy_redn}" + old_name)
625631
for old_name in expr.bounds}
626632

627633
inner_expr = loopy_substitute(expr.inner_expr,
@@ -633,11 +639,6 @@ def map_reduce(self, expr: scalar_expr.Reduce,
633639
inner_expr = self.rec(inner_expr, prstnt_ctx,
634640
local_ctx.copy(reduction_bounds=new_bounds))
635641

636-
try:
637-
loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[expr.op]
638-
except KeyError:
639-
raise NotImplementedError(expr.op)
640-
641642
inner_expr = LoopyReduction(loopy_redn,
642643
tuple(unique_names_mapping.values()),
643644
inner_expr)

0 commit comments

Comments
 (0)