Skip to content

Commit c29f294

Browse files
committed
Add some types in scan
1 parent b1ab035 commit c29f294

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

pyopencl/scan.py

Lines changed: 29 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
import logging
2727
from abc import ABC, abstractmethod
2828
from dataclasses import dataclass
29-
from typing import Any
29+
from typing import TYPE_CHECKING, Any, cast
3030

3131
import numpy as np
3232

@@ -49,6 +49,10 @@
4949
)
5050

5151

52+
if TYPE_CHECKING:
53+
from collections.abc import Sequence
54+
55+
5256
logger = logging.getLogger(__name__)
5357

5458

@@ -900,7 +904,7 @@ class _BuiltScanKernelInfo:
900904
class _GeneratedFinalUpdateKernelInfo:
901905
source: str
902906
kernel_name: str
903-
scalar_arg_dtypes: list[np.dtype | None]
907+
scalar_arg_dtypes: Sequence[np.dtype | None]
904908
update_wg_size: int
905909

906910
def build(self,
@@ -942,7 +946,7 @@ def __init__(
942946
name_prefix: str = "scan",
943947
options: Any = None,
944948
preamble: str = "",
945-
devices: cl.Device | None = None) -> None:
949+
devices: Sequence[cl.Device] | None = None) -> None:
946950
"""
947951
:arg ctx: a :class:`pyopencl.Context` within which the code
948952
for this scan kernel will be generated.
@@ -1031,7 +1035,8 @@ def __init__(
10311035
if input_fetch_exprs is None:
10321036
input_fetch_exprs = []
10331037

1034-
self.context = ctx
1038+
self.context: cl.Context = ctx
1039+
self.dtype: np.dtype[Any]
10351040
dtype = self.dtype = np.dtype(dtype)
10361041

10371042
if neutral is None:
@@ -1044,43 +1049,43 @@ def __init__(
10441049
if dtype.itemsize % 4 != 0:
10451050
raise TypeError("scan value type must have size divisible by 4 bytes")
10461051

1047-
self.index_dtype = np.dtype(index_dtype)
1052+
self.index_dtype: np.dtype[np.integer] = np.dtype(index_dtype)
10481053
if np.iinfo(self.index_dtype).min >= 0:
10491054
raise TypeError("index_dtype must be signed")
10501055

10511056
if devices is None:
10521057
devices = ctx.devices
1053-
self.devices = devices
1058+
self.devices: Sequence[cl.Device] = devices
10541059
self.options = options
10551060

10561061
from pyopencl.tools import parse_arg_list
1057-
self.parsed_args = parse_arg_list(arguments)
1062+
self.parsed_args: Sequence[DtypedArgument] = parse_arg_list(arguments)
10581063
from pyopencl.tools import VectorArg
1059-
self.first_array_idx = next(
1064+
self.first_array_idx: int = next(
10601065
i for i, arg in enumerate(self.parsed_args)
10611066
if isinstance(arg, VectorArg))
10621067

1063-
self.input_expr = input_expr
1068+
self.input_expr: str = input_expr
10641069

1065-
self.is_segment_start_expr = is_segment_start_expr
1066-
self.is_segmented = is_segment_start_expr is not None
1067-
if self.is_segmented:
1070+
self.is_segment_start_expr: str | None = is_segment_start_expr
1071+
self.is_segmented: bool = is_segment_start_expr is not None
1072+
if is_segment_start_expr is not None:
10681073
is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)
10691074

1070-
self.output_statement = output_statement
1075+
self.output_statement: str = output_statement
10711076

10721077
for _name, _arg_name, ife_offset in input_fetch_exprs:
10731078
if ife_offset not in [0, -1]:
10741079
raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
1075-
self.input_fetch_exprs = input_fetch_exprs
1080+
self.input_fetch_exprs: Sequence[tuple[str, str, int]] = input_fetch_exprs
10761081

10771082
arg_dtypes = {}
10781083
arg_ctypes = {}
10791084
for arg in self.parsed_args:
10801085
arg_dtypes[arg.name] = arg.dtype
10811086
arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
10821087

1083-
self.name_prefix = name_prefix
1088+
self.name_prefix: str = name_prefix
10841089

10851090
# {{{ set up shared code dict
10861091

@@ -1128,8 +1133,8 @@ def __init__(
11281133

11291134
# }}}
11301135

1131-
self.use_lookbehind_update = "prev_item" in self.output_statement
1132-
self.store_segment_start_flags = (
1136+
self.use_lookbehind_update: bool = "prev_item" in self.output_statement
1137+
self.store_segment_start_flags: bool = (
11331138
self.is_segmented and self.use_lookbehind_update)
11341139

11351140
self.finish_setup()
@@ -1233,8 +1238,8 @@ def _finish_setup_impl(self) -> None:
12331238
# not sure where these go, but roughly this much seems unavailable.
12341239
avail_local_mem -= 0x400
12351240

1236-
is_cpu = self.devices[0].type & cl.device_type.CPU
1237-
is_gpu = self.devices[0].type & cl.device_type.GPU
1241+
is_cpu = bool(self.devices[0].type & cl.device_type.CPU)
1242+
is_gpu = bool(self.devices[0].type & cl.device_type.GPU)
12381243

12391244
if is_cpu:
12401245
# (about the widest vector a CPU can support, also taking
@@ -1260,7 +1265,7 @@ def _finish_setup_impl(self) -> None:
12601265
# k_group_size should be a power of two because of in-kernel
12611266
# division by that number.
12621267

1263-
solutions = []
1268+
solutions: list[tuple[int, int, int]] = []
12641269
for k_exp in range(0, 9):
12651270
for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
12661271
wg_size_multiples):
@@ -1402,7 +1407,7 @@ def get_local_mem_use(
14021407
for arg in self.parsed_args:
14031408
arg_dtypes[arg.name] = arg.dtype
14041409

1405-
fetch_expr_offsets: dict[str, set] = {}
1410+
fetch_expr_offsets: dict[str, set[int]] = {}
14061411
for _name, arg_name, ife_offset in self.input_fetch_exprs:
14071412
fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
14081413

@@ -1428,10 +1433,10 @@ def get_local_mem_use(
14281433
def generate_scan_kernel(
14291434
self,
14301435
max_wg_size: int,
1431-
arguments: list[DtypedArgument],
1436+
arguments: Sequence[DtypedArgument],
14321437
input_expr: str,
14331438
is_segment_start_expr: str | None,
1434-
input_fetch_exprs: list[tuple[str, str, int]],
1439+
input_fetch_exprs: Sequence[tuple[str, str, int]],
14351440
is_first_level: bool,
14361441
store_segment_start_flags: bool,
14371442
k_group_size: int,
@@ -1442,7 +1447,7 @@ def generate_scan_kernel(
14421447
wg_size = _round_down_to_power_of_2(
14431448
min(max_wg_size, 256))
14441449

1445-
kernel_name = self.code_variables["name_prefix"]
1450+
kernel_name = cast("str", self.code_variables["name_prefix"])
14461451
if is_first_level:
14471452
kernel_name += "_lev1"
14481453
else:

0 commit comments

Comments
 (0)