@@ -244,9 +244,10 @@ def kernel_runner(out: Array, *args: P.args, **kwargs: P.kwargs) -> cl.Event:
244244 assert queue is not None
245245
246246 knl = kernel_getter (out , * args , ** kwargs )
247- work_group_info = cast ( "int" , knl .get_work_group_info (
247+ work_group_info = knl .get_work_group_info (
248248 cl .kernel_work_group_info .WORK_GROUP_SIZE ,
249- queue .device ))
249+ queue .device )
250+ assert isinstance (work_group_info , int )
250251 gs , ls = out ._get_sizes (queue , work_group_info )
251252
252253 knl_args = (out , * args , out .size )
@@ -2706,10 +2707,11 @@ def make_func_for_chunk_size(chunk_size):
27062707 if start_i + chunk_size > vec_count :
27072708 knl = make_func_for_chunk_size (vec_count - start_i )
27082709
2709- gs , ls = dest_indices ._get_sizes (queue ,
2710- knl .get_work_group_info (
2711- cl .kernel_work_group_info .WORK_GROUP_SIZE ,
2712- queue .device ))
2710+ work_group_info = knl .get_work_group_info (
2711+ cl .kernel_work_group_info .WORK_GROUP_SIZE ,
2712+ queue .device )
2713+ assert isinstance (work_group_info , int )
2714+ gs , ls = dest_indices ._get_sizes (queue , work_group_info )
27132715
27142716 wait_for_this = (
27152717 * wait_for ,
0 commit comments