2929
3030
3131import enum
32- from typing import TYPE_CHECKING , Any
32+ from typing import TYPE_CHECKING , Any , cast
3333
3434import numpy as np
3535
5151if TYPE_CHECKING :
5252 from collections .abc import Sequence
5353
54+ from pyopencl .array import Array
55+
5456
5557# {{{ elementwise kernel code generator
5658
@@ -269,7 +271,7 @@ def get_kernel(self, use_range: bool):
269271
270272 return knl , arg_descrs
271273
272- def __call__ (self , * args , ** kwargs ) -> cl .Event :
274+ def __call__ (self , * args , ** kwargs : Any ) -> cl .Event :
273275 """
274276 Invoke the generated scalar kernel.
275277
@@ -281,7 +283,7 @@ def __call__(self, *args, **kwargs) -> cl.Event:
281283 range_ = kwargs .pop ("range" , None )
282284 slice_ = kwargs .pop ("slice" , None )
283285 capture_as = kwargs .pop ("capture_as" , None )
284- queue = kwargs .pop ("queue" , None )
286+ queue : cl . CommandQueue | None = kwargs .pop ("queue" , None )
285287 wait_for = kwargs .pop ("wait_for" , None )
286288
287289 if kwargs :
@@ -298,14 +300,14 @@ def __call__(self, *args, **kwargs) -> cl.Event:
298300
299301 # {{{ assemble arg array
300302
301- repr_vec = None
303+ repr_vec : Array | None = None
302304 invocation_args = []
303305
304306 # non-strict because length arg gets appended below
305307 for arg , arg_descr in zip (args , arg_descrs , strict = False ):
306308 if isinstance (arg_descr , VectorArg ):
307309 if repr_vec is None :
308- repr_vec = arg
310+ repr_vec = cast ( "Array" , arg )
309311
310312 invocation_args .append (arg )
311313 else :
@@ -325,6 +327,8 @@ def __call__(self, *args, **kwargs) -> cl.Event:
325327
326328 range_ = slice (* slice_ .indices (repr_vec .size ))
327329
330+ assert queue is not None
331+
328332 max_wg_size = kernel .get_work_group_info (
329333 cl .kernel_work_group_info .WORK_GROUP_SIZE ,
330334 queue .device )
0 commit comments