Skip to content

Commit 9e19860

Browse files
alexfiklinducer
authored andcommitted
feat(typing): return object from kernel_get_work_group_info
1 parent f94b778 commit 9e19860

File tree

6 files changed

+22
-16
lines changed

6 files changed

+22
-16
lines changed

pyopencl/_monkeypatch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,10 @@ def kernel_set_arg_types(self: _cl.Kernel, arg_types):
259259
devs=self.context.devices))
260260

261261

262-
def kernel_get_work_group_info(self: _cl.Kernel, param: int, device: _cl.Device):
262+
def kernel_get_work_group_info(
263+
self: _cl.Kernel,
264+
param: int,
265+
device: _cl.Device) -> object:
263266
try:
264267
wg_info_cache = self._wg_info_cache
265268
except AttributeError:

pyopencl/array.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,10 @@ def kernel_runner(out: Array, *args: P.args, **kwargs: P.kwargs) -> cl.Event:
244244
assert queue is not None
245245

246246
knl = kernel_getter(out, *args, **kwargs)
247-
work_group_info = cast("int", knl.get_work_group_info(
247+
work_group_info = knl.get_work_group_info(
248248
cl.kernel_work_group_info.WORK_GROUP_SIZE,
249-
queue.device))
249+
queue.device)
250+
assert isinstance(work_group_info, int)
250251
gs, ls = out._get_sizes(queue, work_group_info)
251252

252253
knl_args = (out, *args, out.size)
@@ -2706,10 +2707,11 @@ def make_func_for_chunk_size(chunk_size):
27062707
if start_i + chunk_size > vec_count:
27072708
knl = make_func_for_chunk_size(vec_count-start_i)
27082709

2709-
gs, ls = dest_indices._get_sizes(queue,
2710-
knl.get_work_group_info(
2711-
cl.kernel_work_group_info.WORK_GROUP_SIZE,
2712-
queue.device))
2710+
work_group_info = knl.get_work_group_info(
2711+
cl.kernel_work_group_info.WORK_GROUP_SIZE,
2712+
queue.device)
2713+
assert isinstance(work_group_info, int)
2714+
gs, ls = dest_indices._get_sizes(queue, work_group_info)
27132715

27142716
wait_for_this = (
27152717
*wait_for,

pyopencl/characterize/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
"""
2525

2626

27-
from typing import cast
28-
2927
from pytools import memoize
3028

3129
import pyopencl as cl
@@ -70,9 +68,9 @@ def reasonable_work_group_size_multiple(
7068
}
7169
""")
7270
prg.build()
73-
return cast("int", prg.knl.get_work_group_info(
71+
return prg.knl.get_work_group_info(
7472
cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
75-
dev))
73+
dev)
7674

7775

7876
def nv_compute_capability(dev: cl.Device):

pyopencl/elementwise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def __call__(self, *args, **kwargs) -> cl.Event:
327327
max_wg_size = kernel.get_work_group_info(
328328
cl.kernel_work_group_info.WORK_GROUP_SIZE,
329329
queue.device)
330+
assert isinstance(max_wg_size, int)
330331

331332
if range_ is not None:
332333
start = range_.start

pyopencl/reduction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def __init__(
329329

330330
dtype_out = self.dtype_out = np.dtype(dtype_out)
331331

332-
max_group_size = None
332+
max_group_size: int | None = None
333333
trip_count = 0
334334

335335
while True:
@@ -342,6 +342,7 @@ def __init__(
342342
kernel_max_wg_size = self.stage_1_inf.kernel.get_work_group_info(
343343
cl.kernel_work_group_info.WORK_GROUP_SIZE,
344344
ctx.devices[0])
345+
assert isinstance(kernel_max_wg_size, int)
345346

346347
if self.stage_1_inf.group_size <= kernel_max_wg_size:
347348
break

pyopencl/scan.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import logging
2727
from abc import ABC, abstractmethod
2828
from dataclasses import dataclass
29-
from typing import Any
29+
from typing import Any, cast
3030

3131
import numpy as np
3232

@@ -1303,9 +1303,10 @@ def _finish_setup_impl(self) -> None:
13031303
# at the desired work group size? Building it is the
13041304
# only way to find out.
13051305
kernel_max_wg_size = min(
1306-
candidate_scan_info.kernel.get_work_group_info(
1307-
cl.kernel_work_group_info.WORK_GROUP_SIZE,
1308-
dev)
1306+
cast("int",
1307+
candidate_scan_info.kernel.get_work_group_info(
1308+
cl.kernel_work_group_info.WORK_GROUP_SIZE,
1309+
dev))
13091310
for dev in self.devices)
13101311

13111312
if candidate_scan_info.wg_size <= kernel_max_wg_size:

0 commit comments

Comments
 (0)