Skip to content

Commit d858949

Browse files
alexfiklinducer
authored andcommitted
feat(typing): improve types in arraycontext.loopy
1 parent 438a35e commit d858949

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

arraycontext/loopy.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@
5656
THE SOFTWARE.
5757
"""
5858

59-
from typing import TYPE_CHECKING, ClassVar
59+
from abc import ABC
60+
from typing import TYPE_CHECKING, ClassVar, cast
6061

6162
import numpy as np
6263

@@ -83,6 +84,8 @@
8384
from loopy.kernel.instruction import InstructionBase
8485
from pytools.tag import ToTagSetConvertible
8586

87+
from arraycontext import ArrayContext
88+
from arraycontext.typing import ArrayOrScalar, ScalarLike
8689

8790
# {{{ loopy
8891

@@ -116,7 +119,7 @@ def make_loopy_program(
116119
tags=tags)
117120

118121

119-
def get_default_entrypoint(t_unit):
122+
def get_default_entrypoint(t_unit: lp.TranslationUnit) -> lp.LoopKernel:
120123
try:
121124
# main and "kernel callables" branch
122125
return t_unit.default_entrypoint
@@ -128,9 +131,11 @@ def get_default_entrypoint(t_unit):
128131
"translation unit") from err
129132

130133

131-
def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
134+
def _get_scalar_func_loopy_program(
135+
actx: ArrayContext, c_name: str, nargs: int, naxes: int,
136+
) -> lp.TranslationUnit:
132137
@memoize_in(actx, _get_scalar_func_loopy_program)
133-
def get(c_name, nargs, naxes):
138+
def get(c_name: str, nargs: int, naxes: int) -> lp.TranslationUnit:
134139
from pymbolic.primitives import Subscript, Variable
135140

136141
var_names = [f"i{i}" for i in range(naxes)]
@@ -170,7 +175,7 @@ def sub(name: str) -> Variable | Subscript:
170175
return get(c_name, nargs, naxes)
171176

172177

173-
class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
178+
class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace, ABC):
174179
_numpy_to_c_arc_functions: ClassVar[Mapping[str, str]] = {
175180
"arcsin": "asin",
176181
"arccos": "acos",
@@ -185,12 +190,15 @@ class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
185190
_c_to_numpy_arc_functions: ClassVar[Mapping[str, str]] = {c_name: numpy_name
186191
for numpy_name, c_name in _numpy_to_c_arc_functions.items()}
187192

188-
def __getattr__(self, name):
189-
def loopy_implemented_elwise_func(*args):
193+
def __getattr__(self, name: str):
194+
def loopy_implemented_elwise_func(*args: ArrayOrScalar) -> ArrayOrScalar:
190195
if all(np.isscalar(ary) for ary in args):
191-
return getattr(
192-
np, self._c_to_numpy_arc_functions.get(name, name)
193-
)(*args)
196+
result = getattr(
197+
np, self._c_to_numpy_arc_functions.get(name, name)
198+
)(*args)
199+
200+
return cast("ScalarLike", result)
201+
194202
actx = self._array_context
195203
prg = _get_scalar_func_loopy_program(actx,
196204
c_name, nargs=len(args), naxes=len(args[0].shape))
@@ -199,8 +207,8 @@ def loopy_implemented_elwise_func(*args):
199207
return outputs["out"]
200208

201209
if name in self._c_to_numpy_arc_functions:
202-
raise RuntimeError(f"'{name}' in ArrayContext.np has been removed. "
203-
f"Use '{self._c_to_numpy_arc_functions[name]}' as in numpy. ")
210+
raise RuntimeError(f"'{name}' in ArrayContext.np has been removed: "
211+
f"use '{self._c_to_numpy_arc_functions[name]}' (as in numpy)")
204212

205213
# normalize to C names anyway
206214
c_name = self._numpy_to_c_arc_functions.get(name, name)

0 commit comments

Comments
 (0)