Skip to content

Commit 806a010

Browse files
committed
fake_numpy: sharpen min/max/sum types, deprecate amax/amin
1 parent e722fd3 commit 806a010

File tree

5 files changed

+209
-20
lines changed

5 files changed

+209
-20
lines changed

arraycontext/fake_numpy.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from typing import TYPE_CHECKING, Any, Literal, cast, overload
3434

3535
import numpy as np
36+
from typing_extensions import deprecated
3637

3738
from arraycontext.container import (
3839
NotAnArrayContainerError,
@@ -394,32 +395,69 @@ def where(self,
394395

395396
# {{{ reductions
396397

398+
@overload
397399
def sum(self,
398-
a: ArrayOrContainerOrScalar,
400+
a: ArrayOrContainer,
399401
axis: int | tuple[int, ...] | None = None,
400402
dtype: DTypeLike = None,
401-
) -> ArrayOrScalar: ...
403+
) -> Array: ...
404+
@overload
405+
def sum(self,
406+
a: ScalarLike,
407+
axis: int | tuple[int, ...] | None = None,
408+
dtype: DTypeLike = None,
409+
) -> ScalarLike: ...
402410

403-
def max(self,
411+
def sum(self,
404412
a: ArrayOrContainerOrScalar,
405413
axis: int | tuple[int, ...] | None = None,
414+
dtype: DTypeLike = None,
406415
) -> ArrayOrScalar: ...
407416

417+
@overload
418+
def min(self,
419+
a: ArrayOrContainer,
420+
axis: int | tuple[int, ...] | None = None,
421+
) -> Array: ...
422+
@overload
423+
def min(self,
424+
a: ScalarLike,
425+
axis: int | tuple[int, ...] | None = None,
426+
) -> ScalarLike: ...
427+
408428
def min(self,
409429
a: ArrayOrContainerOrScalar,
410430
axis: int | tuple[int, ...] | None = None,
411431
) -> ArrayOrScalar: ...
412432

413-
def amax(self,
433+
@overload
434+
def max(self,
435+
a: ArrayOrContainer,
436+
axis: int | tuple[int, ...] | None = None,
437+
) -> Array: ...
438+
@overload
439+
def max(self,
440+
a: ScalarLike,
441+
axis: int | tuple[int, ...] | None = None,
442+
) -> ScalarLike: ...
443+
444+
def max(self,
414445
a: ArrayOrContainerOrScalar,
415446
axis: int | tuple[int, ...] | None = None,
416447
) -> ArrayOrScalar: ...
417448

449+
@deprecated("use min instead")
418450
def amin(self,
419451
a: ArrayOrContainerOrScalar,
420452
axis: int | tuple[int, ...] | None = None,
421453
) -> ArrayOrScalar: ...
422454

455+
@deprecated("use max instead")
456+
def amax(self,
457+
a: ArrayOrContainerOrScalar,
458+
axis: int | tuple[int, ...] | None = None,
459+
) -> ArrayOrScalar: ...
460+
423461
def any(self,
424462
a: ArrayOrContainerOrScalar,
425463
) -> ArrayOrScalar: ...

arraycontext/impl/jax/fake_numpy.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
THE SOFTWARE.
2626
"""
2727
from functools import partial, reduce
28-
from typing import TYPE_CHECKING, cast
28+
from typing import TYPE_CHECKING, cast, overload
2929

3030
import numpy as np
3131
from typing_extensions import override
@@ -42,7 +42,7 @@
4242
rec_multimap_array_container,
4343
)
4444
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
45-
from arraycontext.typing import is_scalar_like
45+
from arraycontext.typing import ArrayOrContainer, is_scalar_like
4646

4747

4848
if TYPE_CHECKING:
@@ -205,6 +205,19 @@ def rec_equal(x, y):
205205

206206
# {{{ mathematical functions
207207

208+
@overload
209+
def sum(self,
210+
a: ArrayOrContainer,
211+
axis: int | tuple[int, ...] | None = None,
212+
dtype: DTypeLike = None,
213+
) -> Array: ...
214+
@overload
215+
def sum(self,
216+
a: Scalar,
217+
axis: int | tuple[int, ...] | None = None,
218+
dtype: DTypeLike = None,
219+
) -> Scalar: ...
220+
208221
@override
209222
def sum(self,
210223
a: ArrayOrContainerOrScalar,
@@ -216,6 +229,17 @@ def sum(self,
216229
partial(jnp.sum, axis=axis, dtype=dtype),
217230
a)
218231

232+
@overload
233+
def min(self,
234+
a: ArrayOrContainer,
235+
axis: int | tuple[int, ...] | None = None,
236+
) -> Array: ...
237+
@overload
238+
def min(self,
239+
a: Scalar,
240+
axis: int | tuple[int, ...] | None = None,
241+
) -> Scalar: ...
242+
219243
@override
220244
def min(self,
221245
a: ArrayOrContainerOrScalar,
@@ -224,7 +248,18 @@ def min(self,
224248
return rec_map_reduce_array_container(
225249
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
226250

227-
amin = min
251+
amin = min # pyright: ignore[reportAssignmentType, reportDeprecated]
252+
253+
@overload
254+
def max(self,
255+
a: ArrayOrContainer,
256+
axis: int | tuple[int, ...] | None = None,
257+
) -> Array: ...
258+
@overload
259+
def max(self,
260+
a: Scalar,
261+
axis: int | tuple[int, ...] | None = None,
262+
) -> Scalar: ...
228263

229264
@override
230265
def max(self,
@@ -234,7 +269,7 @@ def max(self,
234269
return rec_map_reduce_array_container(
235270
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
236271

237-
amax = max
272+
amax = max # pyright: ignore[reportDeprecated, reportAssignmentType]
238273

239274
# }}}
240275

arraycontext/impl/numpy/fake_numpy.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
THE SOFTWARE.
2626
"""
2727
from functools import partial, reduce
28-
from typing import TYPE_CHECKING, cast
28+
from typing import TYPE_CHECKING, cast, overload
2929

3030
import numpy as np
3131
from typing_extensions import override
@@ -41,7 +41,7 @@
4141
BaseFakeNumpyLinalgNamespace,
4242
BaseFakeNumpyNamespace,
4343
)
44-
from arraycontext.typing import OrderCF, is_scalar_like
44+
from arraycontext.typing import ArrayOrContainer, OrderCF, is_scalar_like
4545

4646

4747
if TYPE_CHECKING:
@@ -73,6 +73,7 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
7373
"""
7474
A :mod:`numpy` mimic for :class:`NumpyArrayContext`.
7575
"""
76+
@override
7677
def _get_fake_numpy_linalg_namespace(self):
7778
return NumpyFakeNumpyLinalgNamespace(self._array_context)
7879

@@ -95,12 +96,41 @@ def __getattr__(self, name: str):
9596

9697
raise AttributeError(name)
9798

98-
def sum(self, a, axis=None, dtype=None):
99+
@overload
100+
def sum(self,
101+
a: ArrayOrContainer,
102+
axis: int | tuple[int, ...] | None = None,
103+
dtype: DTypeLike = None,
104+
) -> Array: ...
105+
@overload
106+
def sum(self,
107+
a: Scalar,
108+
axis: int | tuple[int, ...] | None = None,
109+
dtype: DTypeLike = None,
110+
) -> Scalar: ...
111+
112+
@override
113+
def sum(self,
114+
a: ArrayOrContainerOrScalar,
115+
axis: int | tuple[int, ...] | None = None,
116+
dtype: DTypeLike = None,
117+
) -> ArrayOrScalar:
99118
return rec_map_reduce_array_container(sum, partial(np.sum,
100119
axis=axis,
101120
dtype=dtype),
102121
a)
103122

123+
@overload
124+
def min(self,
125+
a: ArrayOrContainer,
126+
axis: int | tuple[int, ...] | None = None,
127+
) -> Array: ...
128+
@overload
129+
def min(self,
130+
a: Scalar,
131+
axis: int | tuple[int, ...] | None = None,
132+
) -> Scalar: ...
133+
104134
@override
105135
def min(self,
106136
a: ArrayOrContainerOrScalar,
@@ -109,6 +139,17 @@ def min(self,
109139
return rec_map_reduce_array_container(
110140
partial(reduce, np.minimum), partial(np.amin, axis=axis), a)
111141

142+
@overload
143+
def max(self,
144+
a: ArrayOrContainer,
145+
axis: int | tuple[int, ...] | None = None,
146+
) -> Array: ...
147+
@overload
148+
def max(self,
149+
a: Scalar,
150+
axis: int | tuple[int, ...] | None = None,
151+
) -> Scalar: ...
152+
112153
@override
113154
def max(self,
114155
a: ArrayOrContainerOrScalar,

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import operator
3333
from functools import partial, reduce
34-
from typing import TYPE_CHECKING, cast
34+
from typing import TYPE_CHECKING, cast, overload
3535
from warnings import warn
3636

3737
import numpy as np
@@ -49,7 +49,7 @@
4949
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
5050
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
5151
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
52-
from arraycontext.typing import OrderCF, is_scalar_like
52+
from arraycontext.typing import ArrayOrContainer, OrderCF, ScalarLike, is_scalar_like
5353

5454

5555
if TYPE_CHECKING:
@@ -341,7 +341,25 @@ def inner(ary: ArrayOrScalar) -> ArrayOrScalar:
341341

342342
# {{{ mathematical functions
343343

344-
def sum(self, a, axis=None, dtype=None):
344+
@overload
345+
def sum(self,
346+
a: ArrayOrContainer,
347+
axis: int | tuple[int, ...] | None = None,
348+
dtype: DTypeLike = None,
349+
) -> Array: ...
350+
@overload
351+
def sum(self,
352+
a: ScalarLike,
353+
axis: int | tuple[int, ...] | None = None,
354+
dtype: DTypeLike = None,
355+
) -> ScalarLike: ...
356+
357+
@override
358+
def sum(self,
359+
a: ArrayOrContainerOrScalar,
360+
axis: int | tuple[int, ...] | None = None,
361+
dtype: DTypeLike = None,
362+
) -> ArrayOrScalar:
345363
if isinstance(axis, int):
346364
axis = axis,
347365

@@ -358,6 +376,17 @@ def maximum(self, x, y):
358376
partial(cl_array.maximum, queue=self._array_context.queue),
359377
x, y)
360378

379+
@overload
380+
def max(self,
381+
a: ArrayOrContainer,
382+
axis: int | tuple[int, ...] | None = None,
383+
) -> Array: ...
384+
@overload
385+
def max(self,
386+
a: ScalarLike,
387+
axis: int | tuple[int, ...] | None = None,
388+
) -> ScalarLike: ...
389+
361390
@override
362391
def max(self,
363392
a: ArrayOrContainerOrScalar,
@@ -379,13 +408,24 @@ def _rec_max(ary):
379408
_rec_max,
380409
a)
381410

382-
amax = max
411+
amax = max # pyright: ignore[reportAssignmentType, reportDeprecated]
383412

384413
def minimum(self, x, y):
385414
return rec_multimap_array_container(
386415
partial(cl_array.minimum, queue=self._array_context.queue),
387416
x, y)
388417

418+
@overload
419+
def min(self,
420+
a: ArrayOrContainer,
421+
axis: int | tuple[int, ...] | None = None,
422+
) -> Array: ...
423+
@overload
424+
def min(self,
425+
a: ScalarLike,
426+
axis: int | tuple[int, ...] | None = None,
427+
) -> ScalarLike: ...
428+
389429
@override
390430
def min(self,
391431
a: ArrayOrContainerOrScalar,
@@ -406,7 +446,7 @@ def _rec_min(ary):
406446
_rec_min,
407447
a)
408448

409-
amin = min
449+
amin = min # pyright: ignore[reportAssignmentType, reportDeprecated]
410450

411451
def absolute(self, a):
412452
return self.abs(a)

0 commit comments

Comments
 (0)