Skip to content

Commit 9b5f7c9

Browse files
alexfiklinducer
authored andcommitted
feat(typing): finish typing cltypes.py
1 parent ca4c680 commit 9b5f7c9

File tree

1 file changed

+42
-27
lines changed

1 file changed

+42
-27
lines changed

pyopencl/cltypes.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@
2222
"""
2323

2424
import warnings
25-
from typing import Any
25+
from typing import TYPE_CHECKING, Any, cast
2626

2727
import numpy as np
2828

2929
from pyopencl.tools import get_or_register_dtype
3030

3131

32+
if TYPE_CHECKING:
33+
import builtins
34+
from collections.abc import MutableSequence
35+
3236
if __file__.endswith("array.py"):
3337
warnings.warn(
3438
"pyopencl.array.vec is deprecated. Please use pyopencl.cltypes.",
@@ -53,37 +57,41 @@
5357

5458
# {{{ vector types
5559

56-
def _create_vector_types():
60+
def _create_vector_types() -> tuple[
61+
dict[tuple[np.dtype[Any], builtins.int], np.dtype[Any]],
62+
dict[np.dtype[Any], tuple[np.dtype[Any], builtins.int]]]:
5763
mapping = [(k, globals()[k]) for k in
5864
["char", "uchar", "short", "ushort", "int",
5965
"uint", "long", "ulong", "float", "double"]]
6066

61-
def set_global(key, val):
67+
def set_global(key: str, val: np.dtype[Any]) -> None:
6268
globals()[key] = val
6369

64-
vec_types = {}
65-
vec_type_to_scalar_and_count = {}
70+
vec_types: dict[tuple[np.dtype[Any], builtins.int], np.dtype[Any]] = {}
71+
vec_type_to_scalar_and_count: dict[np.dtype[Any],
72+
tuple[np.dtype[Any], builtins.int]] = {}
6673

6774
field_names = ["x", "y", "z", "w"]
6875

6976
counts = [2, 3, 4, 8, 16]
7077

7178
for base_name, base_type in mapping:
7279
for count in counts:
73-
name = "%s%d" % (base_name, count)
74-
75-
titles = field_names[:count]
80+
name = f"{base_name}{count}"
81+
titles = cast("MutableSequence[str | None]", field_names[:count])
7682

7783
padded_count = count
7884
if count == 3:
7985
padded_count = 4
8086

81-
names = ["s%d" % i for i in range(count)]
87+
names = [f"s{i}" for i in range(count)]
8288
while len(names) < padded_count:
83-
names.append("padding%d" % (len(names) - count))
89+
pad = len(names) - count
90+
names.append(f"padding{pad}")
8491

8592
if len(titles) < len(names):
86-
titles.extend((len(names) - len(titles)) * [None])
93+
pad = len(names) - len(titles)
94+
titles.extend([None] * pad)
8795

8896
try:
8997
dtype = np.dtype({
@@ -96,14 +104,16 @@ def set_global(key, val):
96104
for (n, title)
97105
in zip(names, titles, strict=True)])
98106
except TypeError:
99-
dtype = np.dtype([(n, base_type) for (n, title)
100-
in zip(names, titles, strict=True)])
107+
dtype = np.dtype([(n, base_type) for n in names])
101108

109+
assert isinstance(dtype, np.dtype)
102110
get_or_register_dtype(name, dtype)
103-
104111
set_global(name, dtype)
105112

106-
def create_array(dtype, count, padded_count, *args, **kwargs):
113+
def create_array(dtype: np.dtype[Any],
114+
count: int,
115+
padded_count: int,
116+
*args: Any, **kwargs: Any) -> dict[str, Any]:
107117
if len(args) < count:
108118
from warnings import warn
109119
warn("default values for make_xxx are deprecated;"
@@ -116,21 +126,26 @@ def create_array(dtype, count, padded_count, *args, **kwargs):
116126
{"array": np.array,
117127
"padded_args": padded_args,
118128
"dtype": dtype})
119-
for key, val in list(kwargs.items()):
129+
130+
for key, val in kwargs.items():
120131
array[key] = val
132+
121133
return array
122134

123-
set_global("make_" + name, eval(
124-
"lambda *args, **kwargs: create_array(dtype, %i, %i, "
125-
"*args, **kwargs)" % (count, padded_count),
126-
{"create_array": create_array, "dtype": dtype}))
127-
set_global("filled_" + name, eval(
128-
"lambda val: make_%s(*[val]*%i)" % (name, count)))
129-
set_global("zeros_" + name, eval("lambda: filled_%s(0)" % (name)))
130-
set_global("ones_" + name, eval("lambda: filled_%s(1)" % (name)))
131-
132-
vec_types[np.dtype(base_type), count] = dtype
133-
vec_type_to_scalar_and_count[dtype] = np.dtype(base_type), count
135+
set_global(
136+
f"make_{name}",
137+
eval("lambda *args, **kwargs: "
138+
f"create_array(dtype, {count}, {padded_count}, *args, **kwargs)",
139+
{"create_array": create_array, "dtype": dtype}))
140+
set_global(
141+
f"filled_{name}",
142+
eval(f"lambda val: make_{name}(*[val]*{count})"))
143+
set_global(f"zeros_{name}", eval(f"lambda: filled_{name}(0)"))
144+
set_global(f"ones_{name}", eval(f"lambda: filled_{name}(1)"))
145+
146+
base_dtype = np.dtype(base_type)
147+
vec_types[base_dtype, count] = dtype
148+
vec_type_to_scalar_and_count[dtype] = base_dtype, count
134149

135150
return vec_types, vec_type_to_scalar_and_count
136151

0 commit comments

Comments
 (0)