Skip to content

Commit 28c2ac1

Browse files
committed
Fix/add some dtype-related typing
1 parent b34d6d5 commit 28c2ac1

File tree

4 files changed

+27
-13
lines changed

4 files changed

+27
-13
lines changed

pyopencl/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ def _copy(dest: Array, src: Array) -> cl.Kernel:
11251125
dest.context, dest.dtype, src.dtype)
11261126

11271127
def _new_like_me(self,
1128-
dtype: DTypeLike = None,
1128+
dtype: DTypeLike | None = None,
11291129
queue: cl.CommandQueue | None = None) -> Self:
11301130
if dtype is None:
11311131
dtype = self.dtype

pyopencl/capture_call.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def capture_kernel_call(
139139
cg("prg = cl.Program(ctx, CODE).build()")
140140
cg("knl = prg.%s" % kernel.function_name)
141141
if hasattr(kernel, "_scalar_arg_dtypes"):
142-
def strify_dtype(d: DTypeLike):
142+
def strify_dtype(d: DTypeLike | None):
143143
if d is None:
144144
return "None"
145145

pyopencl/compyte

Submodule compyte updated 1 file

pyopencl/reduction.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -656,13 +656,16 @@ def get_sum_kernel(ctx, dtype_out, dtype_in):
656656
)
657657

658658

659-
def _get_dot_expr(dtype_out, dtype_a, dtype_b, conjugate_first,
660-
has_double_support, index_expr="i"):
659+
def _get_dot_expr(
660+
dtype_out: np.dtype[Any] | None,
661+
dtype_a: np.dtype[Any],
662+
dtype_b: np.dtype[Any] | None,
663+
conjugate_first: bool,
664+
has_double_support: bool,
665+
index_expr: str = "i"
666+
):
661667
if dtype_b is None:
662-
if dtype_a is None:
663-
dtype_b = dtype_out
664-
else:
665-
dtype_b = dtype_a
668+
dtype_b = dtype_a
666669

667670
if dtype_out is None:
668671
from pyopencl.compyte.array import get_common_dtype
@@ -700,8 +703,13 @@ def _get_dot_expr(dtype_out, dtype_a, dtype_b, conjugate_first,
700703

701704

702705
@context_dependent_memoize
703-
def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None,
704-
conjugate_first=False):
706+
def get_dot_kernel(
707+
ctx: cl.Context,
708+
dtype_out: np.dtype[Any] | None,
709+
dtype_a: np.dtype[Any],
710+
dtype_b: np.dtype[Any],
711+
conjugate_first: bool = False
712+
):
705713
from pyopencl.characterize import has_double_support
706714
map_expr, dtype_out, dtype_b = _get_dot_expr(
707715
dtype_out, dtype_a, dtype_b, conjugate_first,
@@ -726,8 +734,14 @@ def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None,
726734

727735

728736
@context_dependent_memoize
729-
def get_subset_dot_kernel(ctx, dtype_out, dtype_subset, dtype_a=None, dtype_b=None,
730-
conjugate_first=False):
737+
def get_subset_dot_kernel(
738+
ctx: cl.Context,
739+
dtype_out: np.dtype[Any] | None,
740+
dtype_subset: np.dtype[Any],
741+
dtype_a: np.dtype[Any],
742+
dtype_b: np.dtype[Any],
743+
conjugate_first: bool = False
744+
):
731745
from pyopencl.characterize import has_double_support
732746
map_expr, dtype_out, dtype_b = _get_dot_expr(
733747
dtype_out, dtype_a, dtype_b, conjugate_first,

0 commit comments

Comments
 (0)