Skip to content

Commit 1cad1ba

Browse files
committed
feat(typing): small fixes in elementwise
1 parent a2b3416 commit 1cad1ba

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

pyopencl/_monkeypatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __getattr__(self, name: str):
221221
kernel_old_get_work_group_info = _cl.Kernel.get_work_group_info
222222

223223

224-
def kernel_set_arg_types(self: _cl.Kernel, arg_types):
224+
def kernel_set_arg_types(self: _cl.Kernel, arg_types) -> None:
225225
arg_types = tuple(arg_types)
226226

227227
# {{{ arg counting bug handling

pyopencl/elementwise.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import pyopencl as cl
3939
from pyopencl.tools import (
40+
Argument,
4041
DtypedArgument,
4142
KernelTemplateBase,
4243
ScalarArg,
@@ -55,7 +56,7 @@
5556

5657
def get_elwise_program(
5758
context: cl.Context,
58-
arguments: list[DtypedArgument],
59+
arguments: Sequence[Argument],
5960
operation: str, *,
6061
name: str = "elwise_kernel",
6162
options: Any = None,
@@ -123,27 +124,27 @@ def get_elwise_program(
123124

124125
def get_elwise_kernel_and_types(
125126
context: cl.Context,
126-
arguments: str | Sequence[DtypedArgument],
127+
arguments: str | Sequence[Argument],
127128
operation: str, *,
128129
name: str = "elwise_kernel",
129130
options: Any = None,
130131
preamble: str = "",
131132
use_range: bool = False,
132-
**kwargs: Any) -> tuple[cl.Kernel, list[DtypedArgument]]:
133+
**kwargs: Any) -> tuple[cl.Kernel, Sequence[Argument]]:
133134

134135
from pyopencl.tools import get_arg_offset_adjuster_code, parse_arg_list
135136
parsed_args = parse_arg_list(arguments, with_offset=True)
136137

137138
auto_preamble = kwargs.pop("auto_preamble", True)
138139

139-
pragmas = []
140-
includes = []
140+
pragmas: list[str] = []
141+
includes: list[str] = []
141142
have_double_pragma = False
142143
have_complex_include = False
143144

144145
if auto_preamble:
145146
for arg in parsed_args:
146-
if arg.dtype in [np.float64, np.complex128]:
147+
if arg.dtype.type in [np.float64, np.complex128]:
147148
if not have_double_pragma:
148149
pragmas.append("""
149150
#if __OPENCL_C_VERSION__ < 120
@@ -186,14 +187,14 @@ def get_elwise_kernel_and_types(
186187

187188
def get_elwise_kernel(
188189
context: cl.Context,
189-
arguments: str | list[DtypedArgument],
190+
arguments: str | Sequence[Argument],
190191
operation: str, *,
191192
name: str = "elwise_kernel",
192193
options: Any = None, **kwargs: Any) -> cl.Kernel:
193194
"""Return a L{pyopencl.Kernel} that performs the same scalar operation
194195
on one or several vectors.
195196
"""
196-
func, arguments = get_elwise_kernel_and_types(
197+
func, _arguments = get_elwise_kernel_and_types(
197198
context, arguments, operation,
198199
name=name, options=options, **kwargs)
199200

@@ -233,7 +234,7 @@ class ElementwiseKernel:
233234
def __init__(
234235
self,
235236
context: cl.Context,
236-
arguments: str | Sequence[DtypedArgument],
237+
arguments: str | Sequence[Argument],
237238
operation: str,
238239
name: str = "elwise_kernel",
239240
options: Any = None, **kwargs: Any) -> None:

0 commit comments

Comments
 (0)