Skip to content

Commit 57dc042

Browse files
committed
get_cached_kernel_executor -> get_cached_kernel
1 parent 494220f commit 57dc042

File tree

6 files changed

+20
-21
lines changed

6 files changed

+20
-21
lines changed

sumpy/e2e.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
282282
src_rscale = centers.dtype.type(kwargs.pop("src_rscale"))
283283
tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
284284

285-
knl = self.get_cached_kernel_executor()
285+
knl = self.get_cached_kernel()
286286
result = actx.call_loopy(
287287
knl,
288288
centers=centers,
@@ -527,7 +527,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
527527
tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
528528
src_expansions = kwargs.pop("src_expansions")
529529

530-
knl = self.get_cached_kernel_executor(result_dtype=src_expansions.dtype)
530+
knl = self.get_cached_kernel(result_dtype=src_expansions.dtype)
531531
result = actx.call_loopy(
532532
knl,
533533
src_expansions=src_expansions,
@@ -641,7 +641,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
641641
"m2l_translation_classes_dependent_data")
642642
result_dtype = m2l_translation_classes_dependent_data.dtype
643643

644-
knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
644+
knl = self.get_cached_kernel(result_dtype=result_dtype)
645645
result = actx.call_loopy(
646646
knl,
647647
src_rscale=src_rscale,
@@ -737,7 +737,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
737737
preprocessed_src_expansions = kwargs.pop("preprocessed_src_expansions")
738738
result_dtype = preprocessed_src_expansions.dtype
739739

740-
knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
740+
knl = self.get_cached_kernel(result_dtype=result_dtype)
741741
result = actx.call_loopy(
742742
knl,
743743
preprocessed_src_expansions=preprocessed_src_expansions,
@@ -838,7 +838,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
838838
tgt_expansions = kwargs.pop("tgt_expansions")
839839
result_dtype = tgt_expansions.dtype
840840

841-
knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
841+
knl = self.get_cached_kernel(result_dtype=result_dtype)
842842
result = actx.call_loopy(
843843
knl,
844844
tgt_expansions=tgt_expansions,
@@ -958,7 +958,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
958958
src_rscale = centers.dtype.type(kwargs.pop("src_rscale"))
959959
tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
960960

961-
knl = self.get_cached_kernel_executor()
961+
knl = self.get_cached_kernel()
962962
result = actx.call_loopy(
963963
knl,
964964
centers=centers,
@@ -1065,7 +1065,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
10651065
src_rscale = centers.dtype.type(kwargs.pop("src_rscale"))
10661066
tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
10671067

1068-
knl = self.get_cached_kernel_executor()
1068+
knl = self.get_cached_kernel()
10691069
result = actx.call_loopy(
10701070
knl,
10711071
centers=centers,

sumpy/e2p.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
218218
# meaningfully inferred. Make the type of rscale explicit.
219219
rscale = centers.dtype.type(kwargs.pop("rscale"))
220220

221-
knl = self.get_cached_kernel_executor()
221+
knl = self.get_cached_kernel()
222222
result = actx.call_loopy(
223223
knl,
224224
centers=centers, rscale=rscale, **kwargs)
@@ -337,7 +337,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
337337
# meaningfully inferred. Make the type of rscale explicit.
338338
rscale = centers.dtype.type(kwargs.pop("rscale"))
339339

340-
knl = self.get_cached_kernel_executor()
340+
knl = self.get_cached_kernel()
341341
result = actx.call_loopy(
342342
knl,
343343
centers=centers,

sumpy/p2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
133133
dtype = centers[0].dtype if is_obj_array_like(centers) else centers.dtype
134134
rscale = dtype.type(kwargs.pop("rscale"))
135135

136-
knl = self.get_cached_kernel_executor(
136+
knl = self.get_cached_kernel(
137137
sources_is_obj_array=is_obj_array_like(sources),
138138
centers_is_obj_array=is_obj_array_like(centers))
139139

sumpy/p2p.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def __call__(self,
272272
strength: Sequence[Array],
273273
**kwargs: Any,
274274
) -> tuple[cl.Event, Sequence[Array]]:
275-
knl = self.get_cached_kernel_executor(
275+
knl = self.get_cached_kernel(
276276
targets_is_obj_array=is_obj_array_like(targets),
277277
sources_is_obj_array=is_obj_array_like(sources))
278278

@@ -345,7 +345,7 @@ def __call__(self,
345345
sources: ObjectArray1D[Array] | Array,
346346
**kwargs: Any,
347347
) -> Sequence[Array]:
348-
knl = self.get_cached_kernel_executor(
348+
knl = self.get_cached_kernel(
349349
targets_is_obj_array=is_obj_array_like(targets),
350350
sources_is_obj_array=is_obj_array_like(sources))
351351

@@ -439,7 +439,6 @@ def get_optimized_kernel(self, targets_is_obj_array, sources_is_obj_array):
439439
knl = self._allow_redundant_execution_of_knl_scaling(knl)
440440
knl = lp.set_options(knl,
441441
enforce_variable_access_ordered="no_check")
442-
knl = register_optimization_preambles(knl, self.device)
443442

444443
return knl
445444

@@ -467,7 +466,7 @@ def __call__(self,
467466
:returns: a one-dimensional array of interactions, for each index pair
468467
in (*srcindices*, *tgtindices*)
469468
"""
470-
knl = self.get_cached_kernel_executor(
469+
knl = self.get_cached_kernel(
471470
targets_is_obj_array=is_obj_array_like(targets),
472471
sources_is_obj_array=is_obj_array_like(sources))
473472

@@ -824,7 +823,7 @@ def __call__(self,
824823
source_dtype = None
825824
strength_dtype = None
826825

827-
knl = self.get_cached_kernel_executor(
826+
knl = self.get_cached_kernel(
828827
max_nsources_in_one_box=max_nsources_in_one_box,
829828
max_ntargets_in_one_box=max_ntargets_in_one_box,
830829
local_mem_size=actx.queue.device.local_mem_size,

sumpy/qbx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def __call__(self, actx: PyOpenCLArrayContext,
317317
already multiplied in.
318318
"""
319319

320-
knl = self.get_cached_kernel_executor(
320+
knl = self.get_cached_kernel(
321321
is_cpu=is_cl_cpu(actx),
322322
targets_is_obj_array=is_obj_array_like(targets),
323323
sources_is_obj_array=is_obj_array_like(sources),
@@ -396,7 +396,7 @@ def get_kernel(self):
396396

397397
def __call__(self, actx: PyOpenCLArrayContext,
398398
targets, sources, centers, expansion_radii, **kwargs):
399-
knl = self.get_cached_kernel_executor(
399+
knl = self.get_cached_kernel(
400400
is_cpu=is_cl_cpu(actx),
401401
targets_is_obj_array=is_obj_array_like(targets),
402402
sources_is_obj_array=is_obj_array_like(sources),
@@ -525,7 +525,7 @@ def __call__(self, actx: PyOpenCLArrayContext,
525525
in (*srcindices*, *tgtindices*)
526526
"""
527527

528-
knl = self.get_cached_kernel_executor(
528+
knl = self.get_cached_kernel(
529529
targets_is_obj_array=is_obj_array_like(targets),
530530
sources_is_obj_array=is_obj_array_like(sources),
531531
centers_is_obj_array=is_obj_array_like(centers))

sumpy/tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def get_optimized_kernel(self, **kwargs: Any) -> lp.TranslationUnit:
454454
...
455455

456456
@memoize_method
457-
def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
457+
def get_cached_kernel(self, **kwargs) -> lp.TranslationUnit:
458458
from sumpy import CACHING_ENABLED, NO_CACHE_KERNELS, OPT_ENABLED, code_cache
459459

460460
if CACHING_ENABLED and not (
@@ -472,7 +472,7 @@ def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
472472
try:
473473
result = code_cache[cache_key]
474474
logger.debug("%s: kernel cache hit [key=%s]", self.name, cache_key)
475-
return result.executor(self.context)
475+
return result
476476
except KeyError:
477477
pass
478478

@@ -493,7 +493,7 @@ def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
493493
NO_CACHE_KERNELS and self.name in NO_CACHE_KERNELS):
494494
code_cache.store_if_not_present(cache_key, knl)
495495

496-
return knl.executor(self.context)
496+
return knl
497497

498498
@staticmethod
499499
def _allow_redundant_execution_of_knl_scaling(

0 commit comments

Comments
 (0)