2626import logging
2727from abc import ABC , abstractmethod
2828from dataclasses import dataclass
29- from typing import Any
29+ from typing import TYPE_CHECKING , Any , cast
3030
3131import numpy as np
3232
4949)
5050
5151
52+ if TYPE_CHECKING :
53+ from collections .abc import Sequence
54+
55+
5256logger = logging .getLogger (__name__ )
5357
5458
@@ -900,7 +904,7 @@ class _BuiltScanKernelInfo:
900904class _GeneratedFinalUpdateKernelInfo :
901905 source : str
902906 kernel_name : str
903- scalar_arg_dtypes : list [np .dtype | None ]
907+ scalar_arg_dtypes : Sequence [np .dtype | None ]
904908 update_wg_size : int
905909
906910 def build (self ,
@@ -942,7 +946,7 @@ def __init__(
942946 name_prefix : str = "scan" ,
943947 options : Any = None ,
944948 preamble : str = "" ,
945- devices : cl .Device | None = None ) -> None :
949+ devices : Sequence [ cl .Device ] | None = None ) -> None :
946950 """
947951 :arg ctx: a :class:`pyopencl.Context` within which the code
948952 for this scan kernel will be generated.
@@ -1031,7 +1035,8 @@ def __init__(
10311035 if input_fetch_exprs is None :
10321036 input_fetch_exprs = []
10331037
1034- self .context = ctx
1038+ self .context : cl .Context = ctx
1039+ self .dtype : np .dtype [Any ]
10351040 dtype = self .dtype = np .dtype (dtype )
10361041
10371042 if neutral is None :
@@ -1044,43 +1049,43 @@ def __init__(
10441049 if dtype .itemsize % 4 != 0 :
10451050 raise TypeError ("scan value type must have size divisible by 4 bytes" )
10461051
1047- self .index_dtype = np .dtype (index_dtype )
1052+ self .index_dtype : np . dtype [ np . integer ] = np .dtype (index_dtype )
10481053 if np .iinfo (self .index_dtype ).min >= 0 :
10491054 raise TypeError ("index_dtype must be signed" )
10501055
10511056 if devices is None :
10521057 devices = ctx .devices
1053- self .devices = devices
1058+ self .devices : Sequence [ cl . Device ] = devices
10541059 self .options = options
10551060
10561061 from pyopencl .tools import parse_arg_list
1057- self .parsed_args = parse_arg_list (arguments )
1062+ self .parsed_args : Sequence [ DtypedArgument ] = parse_arg_list (arguments )
10581063 from pyopencl .tools import VectorArg
1059- self .first_array_idx = next (
1064+ self .first_array_idx : int = next (
10601065 i for i , arg in enumerate (self .parsed_args )
10611066 if isinstance (arg , VectorArg ))
10621067
1063- self .input_expr = input_expr
1068+ self .input_expr : str = input_expr
10641069
1065- self .is_segment_start_expr = is_segment_start_expr
1066- self .is_segmented = is_segment_start_expr is not None
1067- if self . is_segmented :
1070+ self .is_segment_start_expr : str | None = is_segment_start_expr
1071+ self .is_segmented : bool = is_segment_start_expr is not None
1072+ if is_segment_start_expr is not None :
10681073 is_segment_start_expr = _process_code_for_macro (is_segment_start_expr )
10691074
1070- self .output_statement = output_statement
1075+ self .output_statement : str = output_statement
10711076
10721077 for _name , _arg_name , ife_offset in input_fetch_exprs :
10731078 if ife_offset not in [0 , - 1 ]:
10741079 raise RuntimeError ("input_fetch_expr offsets must either be 0 or -1" )
1075- self .input_fetch_exprs = input_fetch_exprs
1080+ self .input_fetch_exprs : Sequence [ tuple [ str , str , int ]] = input_fetch_exprs
10761081
10771082 arg_dtypes = {}
10781083 arg_ctypes = {}
10791084 for arg in self .parsed_args :
10801085 arg_dtypes [arg .name ] = arg .dtype
10811086 arg_ctypes [arg .name ] = dtype_to_ctype (arg .dtype )
10821087
1083- self .name_prefix = name_prefix
1088+ self .name_prefix : str = name_prefix
10841089
10851090 # {{{ set up shared code dict
10861091
@@ -1128,8 +1133,8 @@ def __init__(
11281133
11291134 # }}}
11301135
1131- self .use_lookbehind_update = "prev_item" in self .output_statement
1132- self .store_segment_start_flags = (
1136+ self .use_lookbehind_update : bool = "prev_item" in self .output_statement
1137+ self .store_segment_start_flags : bool = (
11331138 self .is_segmented and self .use_lookbehind_update )
11341139
11351140 self .finish_setup ()
@@ -1233,8 +1238,8 @@ def _finish_setup_impl(self) -> None:
12331238 # not sure where these go, but roughly this much seems unavailable.
12341239 avail_local_mem -= 0x400
12351240
1236- is_cpu = self .devices [0 ].type & cl .device_type .CPU
1237- is_gpu = self .devices [0 ].type & cl .device_type .GPU
1241+ is_cpu = bool ( self .devices [0 ].type & cl .device_type .CPU )
1242+ is_gpu = bool ( self .devices [0 ].type & cl .device_type .GPU )
12381243
12391244 if is_cpu :
12401245 # (about the widest vector a CPU can support, also taking
@@ -1260,7 +1265,7 @@ def _finish_setup_impl(self) -> None:
12601265 # k_group_size should be a power of two because of in-kernel
12611266 # division by that number.
12621267
1263- solutions = []
1268+ solutions : list [ tuple [ int , int , int ]] = []
12641269 for k_exp in range (0 , 9 ):
12651270 for wg_size in range (wg_size_multiples , max_scan_wg_size + 1 ,
12661271 wg_size_multiples ):
@@ -1402,7 +1407,7 @@ def get_local_mem_use(
14021407 for arg in self .parsed_args :
14031408 arg_dtypes [arg .name ] = arg .dtype
14041409
1405- fetch_expr_offsets : dict [str , set ] = {}
1410+ fetch_expr_offsets : dict [str , set [ int ] ] = {}
14061411 for _name , arg_name , ife_offset in self .input_fetch_exprs :
14071412 fetch_expr_offsets .setdefault (arg_name , set ()).add (ife_offset )
14081413
@@ -1428,10 +1433,10 @@ def get_local_mem_use(
14281433 def generate_scan_kernel (
14291434 self ,
14301435 max_wg_size : int ,
1431- arguments : list [DtypedArgument ],
1436+ arguments : Sequence [DtypedArgument ],
14321437 input_expr : str ,
14331438 is_segment_start_expr : str | None ,
1434- input_fetch_exprs : list [tuple [str , str , int ]],
1439+ input_fetch_exprs : Sequence [tuple [str , str , int ]],
14351440 is_first_level : bool ,
14361441 store_segment_start_flags : bool ,
14371442 k_group_size : int ,
@@ -1442,7 +1447,7 @@ def generate_scan_kernel(
14421447 wg_size = _round_down_to_power_of_2 (
14431448 min (max_wg_size , 256 ))
14441449
1445- kernel_name = self .code_variables ["name_prefix" ]
1450+ kernel_name = cast ( "str" , self .code_variables ["name_prefix" ])
14461451 if is_first_level :
14471452 kernel_name += "_lev1"
14481453 else :
0 commit comments