2222"""
2323
2424import warnings
25- from typing import Any
25+ from typing import TYPE_CHECKING , Any , cast
2626
2727import numpy as np
2828
2929from pyopencl .tools import get_or_register_dtype
3030
3131
32+ if TYPE_CHECKING :
33+ import builtins
34+ from collections .abc import MutableSequence
35+
3236if __file__ .endswith ("array.py" ):
3337 warnings .warn (
3438 "pyopencl.array.vec is deprecated. Please use pyopencl.cltypes." ,
5357
5458# {{{ vector types
5559
56- def _create_vector_types ():
60+ def _create_vector_types () -> tuple [
61+ dict [tuple [np .dtype [Any ], builtins .int ], np .dtype [Any ]],
62+ dict [np .dtype [Any ], tuple [np .dtype [Any ], builtins .int ]]]:
5763 mapping = [(k , globals ()[k ]) for k in
5864 ["char" , "uchar" , "short" , "ushort" , "int" ,
5965 "uint" , "long" , "ulong" , "float" , "double" ]]
6066
61- def set_global (key , val ) :
67+ def set_global (key : str , val : np . dtype [ Any ]) -> None :
6268 globals ()[key ] = val
6369
64- vec_types = {}
65- vec_type_to_scalar_and_count = {}
70+ vec_types : dict [tuple [np .dtype [Any ], builtins .int ], np .dtype [Any ]] = {}
71+ vec_type_to_scalar_and_count : dict [np .dtype [Any ],
72+ tuple [np .dtype [Any ], builtins .int ]] = {}
6673
6774 field_names = ["x" , "y" , "z" , "w" ]
6875
6976 counts = [2 , 3 , 4 , 8 , 16 ]
7077
7178 for base_name , base_type in mapping :
7279 for count in counts :
73- name = "%s%d" % (base_name , count )
74-
75- titles = field_names [:count ]
80+ name = f"{ base_name } { count } "
81+ titles = cast ("MutableSequence[str | None]" , field_names [:count ])
7682
7783 padded_count = count
7884 if count == 3 :
7985 padded_count = 4
8086
81- names = ["s%d" % i for i in range (count )]
87+ names = [f"s { i } " for i in range (count )]
8288 while len (names ) < padded_count :
83- names .append ("padding%d" % (len (names ) - count ))
89+ pad = len (names ) - count
90+ names .append (f"padding{ pad } " )
8491
8592 if len (titles ) < len (names ):
86- titles .extend ((len (names ) - len (titles )) * [None ])
93+ pad = len (names ) - len (titles )
94+ titles .extend ([None ] * pad )
8795
8896 try :
8997 dtype = np .dtype ({
@@ -96,14 +104,16 @@ def set_global(key, val):
96104 for (n , title )
97105 in zip (names , titles , strict = True )])
98106 except TypeError :
99- dtype = np .dtype ([(n , base_type ) for (n , title )
100- in zip (names , titles , strict = True )])
107+ dtype = np .dtype ([(n , base_type ) for n in names ])
101108
109+ assert isinstance (dtype , np .dtype )
102110 get_or_register_dtype (name , dtype )
103-
104111 set_global (name , dtype )
105112
106- def create_array (dtype , count , padded_count , * args , ** kwargs ):
113+ def create_array (dtype : np .dtype [Any ],
114+ count : int ,
115+ padded_count : int ,
116+ * args : Any , ** kwargs : Any ) -> dict [str , Any ]:
107117 if len (args ) < count :
108118 from warnings import warn
109119 warn ("default values for make_xxx are deprecated;"
@@ -116,21 +126,26 @@ def create_array(dtype, count, padded_count, *args, **kwargs):
116126 {"array" : np .array ,
117127 "padded_args" : padded_args ,
118128 "dtype" : dtype })
119- for key , val in list (kwargs .items ()):
129+
130+ for key , val in kwargs .items ():
120131 array [key ] = val
132+
121133 return array
122134
123- set_global ("make_" + name , eval (
124- "lambda *args, **kwargs: create_array(dtype, %i, %i, "
125- "*args, **kwargs)" % (count , padded_count ),
126- {"create_array" : create_array , "dtype" : dtype }))
127- set_global ("filled_" + name , eval (
128- "lambda val: make_%s(*[val]*%i)" % (name , count )))
129- set_global ("zeros_" + name , eval ("lambda: filled_%s(0)" % (name )))
130- set_global ("ones_" + name , eval ("lambda: filled_%s(1)" % (name )))
131-
132- vec_types [np .dtype (base_type ), count ] = dtype
133- vec_type_to_scalar_and_count [dtype ] = np .dtype (base_type ), count
135+ set_global (
136+ f"make_{ name } " ,
137+ eval ("lambda *args, **kwargs: "
138+ f"create_array(dtype, { count } , { padded_count } , *args, **kwargs)" ,
139+ {"create_array" : create_array , "dtype" : dtype }))
140+ set_global (
141+ f"filled_{ name } " ,
142+ eval (f"lambda val: make_{ name } (*[val]*{ count } )" ))
143+ set_global (f"zeros_{ name } " , eval (f"lambda: filled_{ name } (0)" ))
144+ set_global (f"ones_{ name } " , eval (f"lambda: filled_{ name } (1)" ))
145+
146+ base_dtype = np .dtype (base_type )
147+ vec_types [base_dtype , count ] = dtype
148+ vec_type_to_scalar_and_count [dtype ] = base_dtype , count
134149
135150 return vec_types , vec_type_to_scalar_and_count
136151
0 commit comments