@@ -656,13 +656,16 @@ def get_sum_kernel(ctx, dtype_out, dtype_in):
656656 )
657657
658658
659- def _get_dot_expr (dtype_out , dtype_a , dtype_b , conjugate_first ,
660- has_double_support , index_expr = "i" ):
659+ def _get_dot_expr (
660+ dtype_out : np .dtype [Any ] | None ,
661+ dtype_a : np .dtype [Any ],
662+ dtype_b : np .dtype [Any ] | None ,
663+ conjugate_first : bool ,
664+ has_double_support : bool ,
665+ index_expr : str = "i"
666+ ):
661667 if dtype_b is None :
662- if dtype_a is None :
663- dtype_b = dtype_out
664- else :
665- dtype_b = dtype_a
668+ dtype_b = dtype_a
666669
667670 if dtype_out is None :
668671 from pyopencl .compyte .array import get_common_dtype
@@ -700,8 +703,13 @@ def _get_dot_expr(dtype_out, dtype_a, dtype_b, conjugate_first,
700703
701704
702705@context_dependent_memoize
703- def get_dot_kernel (ctx , dtype_out , dtype_a = None , dtype_b = None ,
704- conjugate_first = False ):
706+ def get_dot_kernel (
707+ ctx : cl .Context ,
708+ dtype_out : np .dtype [Any ] | None ,
709+ dtype_a : np .dtype [Any ],
710+ dtype_b : np .dtype [Any ],
711+ conjugate_first : bool = False
712+ ):
705713 from pyopencl .characterize import has_double_support
706714 map_expr , dtype_out , dtype_b = _get_dot_expr (
707715 dtype_out , dtype_a , dtype_b , conjugate_first ,
@@ -726,8 +734,14 @@ def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None,
726734
727735
728736@context_dependent_memoize
729- def get_subset_dot_kernel (ctx , dtype_out , dtype_subset , dtype_a = None , dtype_b = None ,
730- conjugate_first = False ):
737+ def get_subset_dot_kernel (
738+ ctx : cl .Context ,
739+ dtype_out : np .dtype [Any ] | None ,
740+ dtype_subset : np .dtype [Any ],
741+ dtype_a : np .dtype [Any ],
742+ dtype_b : np .dtype [Any ],
743+ conjugate_first : bool = False
744+ ):
731745 from pyopencl .characterize import has_double_support
732746 map_expr , dtype_out , dtype_b = _get_dot_expr (
733747 dtype_out , dtype_a , dtype_b , conjugate_first ,
0 commit comments