Skip to content

Commit f008dcc

Browse files
committed
Add some types in elementwise
1 parent c29f294 commit f008dcc

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

pyopencl/elementwise.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
import enum
32-
from typing import TYPE_CHECKING, Any
32+
from typing import TYPE_CHECKING, Any, cast
3333

3434
import numpy as np
3535

@@ -51,6 +51,8 @@
5151
if TYPE_CHECKING:
5252
from collections.abc import Sequence
5353

54+
from pyopencl.array import Array
55+
5456

5557
# {{{ elementwise kernel code generator
5658

@@ -269,7 +271,7 @@ def get_kernel(self, use_range: bool):
269271

270272
return knl, arg_descrs
271273

272-
def __call__(self, *args, **kwargs) -> cl.Event:
274+
def __call__(self, *args, **kwargs: Any) -> cl.Event:
273275
"""
274276
Invoke the generated scalar kernel.
275277
@@ -281,7 +283,7 @@ def __call__(self, *args, **kwargs) -> cl.Event:
281283
range_ = kwargs.pop("range", None)
282284
slice_ = kwargs.pop("slice", None)
283285
capture_as = kwargs.pop("capture_as", None)
284-
queue = kwargs.pop("queue", None)
286+
queue: cl.CommandQueue | None = kwargs.pop("queue", None)
285287
wait_for = kwargs.pop("wait_for", None)
286288

287289
if kwargs:
@@ -298,14 +300,14 @@ def __call__(self, *args, **kwargs) -> cl.Event:
298300

299301
# {{{ assemble arg array
300302

301-
repr_vec = None
303+
repr_vec: Array | None = None
302304
invocation_args = []
303305

304306
# non-strict because length arg gets appended below
305307
for arg, arg_descr in zip(args, arg_descrs, strict=False):
306308
if isinstance(arg_descr, VectorArg):
307309
if repr_vec is None:
308-
repr_vec = arg
310+
repr_vec = cast("Array", arg)
309311

310312
invocation_args.append(arg)
311313
else:
@@ -325,6 +327,8 @@ def __call__(self, *args, **kwargs) -> cl.Event:
325327

326328
range_ = slice(*slice_.indices(repr_vec.size))
327329

330+
assert queue is not None
331+
328332
max_wg_size = kernel.get_work_group_info(
329333
cl.kernel_work_group_info.WORK_GROUP_SIZE,
330334
queue.device)

0 commit comments

Comments
 (0)