5656THE SOFTWARE.
5757"""
5858
59- from typing import TYPE_CHECKING , ClassVar
59+ from abc import ABC
60+ from typing import TYPE_CHECKING , ClassVar , cast
6061
6162import numpy as np
6263
8384 from loopy .kernel .instruction import InstructionBase
8485 from pytools .tag import ToTagSetConvertible
8586
87+ from arraycontext import ArrayContext
88+ from arraycontext .typing import ArrayOrScalar , ScalarLike
8689
8790# {{{ loopy
8891
@@ -116,7 +119,7 @@ def make_loopy_program(
116119 tags = tags )
117120
118121
119- def get_default_entrypoint (t_unit ) :
122+ def get_default_entrypoint (t_unit : lp . TranslationUnit ) -> lp . LoopKernel :
120123 try :
121124 # main and "kernel callables" branch
122125 return t_unit .default_entrypoint
@@ -128,9 +131,11 @@ def get_default_entrypoint(t_unit):
128131 "translation unit" ) from err
129132
130133
131- def _get_scalar_func_loopy_program (actx , c_name , nargs , naxes ):
134+ def _get_scalar_func_loopy_program (
135+ actx : ArrayContext , c_name : str , nargs : int , naxes : int ,
136+ ) -> lp .TranslationUnit :
132137 @memoize_in (actx , _get_scalar_func_loopy_program )
133- def get (c_name , nargs , naxes ) :
138+ def get (c_name : str , nargs : int , naxes : int ) -> lp . TranslationUnit :
134139 from pymbolic .primitives import Subscript , Variable
135140
136141 var_names = [f"i{ i } " for i in range (naxes )]
@@ -170,7 +175,7 @@ def sub(name: str) -> Variable | Subscript:
170175 return get (c_name , nargs , naxes )
171176
172177
173- class LoopyBasedFakeNumpyNamespace (BaseFakeNumpyNamespace ):
178+ class LoopyBasedFakeNumpyNamespace (BaseFakeNumpyNamespace , ABC ):
174179 _numpy_to_c_arc_functions : ClassVar [Mapping [str , str ]] = {
175180 "arcsin" : "asin" ,
176181 "arccos" : "acos" ,
@@ -185,12 +190,15 @@ class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
185190 _c_to_numpy_arc_functions : ClassVar [Mapping [str , str ]] = {c_name : numpy_name
186191 for numpy_name , c_name in _numpy_to_c_arc_functions .items ()}
187192
188- def __getattr__ (self , name ):
189- def loopy_implemented_elwise_func (* args ) :
193+ def __getattr__ (self , name : str ):
194+ def loopy_implemented_elwise_func (* args : ArrayOrScalar ) -> ArrayOrScalar :
190195 if all (np .isscalar (ary ) for ary in args ):
191- return getattr (
192- np , self ._c_to_numpy_arc_functions .get (name , name )
193- )(* args )
196+ result = getattr (
197+ np , self ._c_to_numpy_arc_functions .get (name , name )
198+ )(* args )
199+
200+ return cast ("ScalarLike" , result )
201+
194202 actx = self ._array_context
195203 prg = _get_scalar_func_loopy_program (actx ,
196204 c_name , nargs = len (args ), naxes = len (args [0 ].shape ))
@@ -199,8 +207,8 @@ def loopy_implemented_elwise_func(*args):
199207 return outputs ["out" ]
200208
201209 if name in self ._c_to_numpy_arc_functions :
202- raise RuntimeError (f"'{ name } ' in ArrayContext.np has been removed. "
203- f"Use '{ self ._c_to_numpy_arc_functions [name ]} ' as in numpy. " )
210+ raise RuntimeError (f"'{ name } ' in ArrayContext.np has been removed: "
211+ f"use '{ self ._c_to_numpy_arc_functions [name ]} ' ( as in numpy) " )
204212
205213 # normalize to C names anyway
206214 c_name = self ._numpy_to_c_arc_functions .get (name , name )
0 commit comments