Skip to content

Commit 683460d

Browse files
committed
fix some type annotations
1 parent 31ec76e commit 683460d

File tree

6 files changed

+61
-50
lines changed

6 files changed

+61
-50
lines changed

sumpy/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,9 @@
2424
"""
2525

2626
import os
27+
from collections.abc import Hashable
2728

29+
import loopy as lp
2830
from pytools.persistent_dict import WriteOncePersistentDict
2931

3032
from sumpy.e2e import (
@@ -59,7 +61,7 @@
5961
]
6062

6163

62-
code_cache = (
64+
code_cache: WriteOncePersistentDict[Hashable, lp.TranslationUnit] = (
6365
WriteOncePersistentDict(f"sumpy-code-cache-v7-{VERSION_TEXT}", safe_sync=False))
6466

6567

sumpy/fmm.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@
2929
.. autoclass:: SumpyExpansionWrangler
3030
"""
3131

32+
from typing import TYPE_CHECKING
33+
3234
import numpy as np
3335

3436
from arraycontext import Array
@@ -58,6 +60,10 @@
5860
)
5961

6062

63+
if TYPE_CHECKING:
64+
import pyopencl
65+
66+
6167
# {{{ tree-independent data for wrangler
6268

6369
class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
@@ -731,7 +737,7 @@ def multipole_to_local(self,
731737
local_exps_view_func = self.local_expansions_view
732738

733739
for lev in range(self.tree.nlevels):
734-
wait_for = []
740+
wait_for: list[pyopencl.Event] = []
735741

736742
start, stop = level_start_target_box_nrs[lev:lev+2]
737743
if start == stop:

sumpy/p2p.py

Lines changed: 34 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,10 @@ def get_default_src_tgt_arguments(self):
183183
if self.exclude_self else [])
184184
+ gather_loopy_source_arguments(self.source_kernels))
185185

186-
def get_optimized_kernel(self, targets_is_obj_array, sources_is_obj_array):
186+
def get_optimized_kernel(self, *,
187+
targets_is_obj_array: bool = False,
188+
sources_is_obj_array: bool = False,
189+
**kwargs: Any) -> lp.TranslationUnit:
187190
# FIXME
188191
knl = self.get_kernel()
189192

@@ -194,10 +197,8 @@ def get_optimized_kernel(self, targets_is_obj_array, sources_is_obj_array):
194197

195198
knl = lp.split_iname(knl, "itgt", 1024, outer_tag="g.0")
196199
knl = self._allow_redundant_execution_of_knl_scaling(knl)
197-
knl = lp.set_options(knl,
198-
enforce_variable_access_ordered="no_check")
200+
knl = lp.set_options(knl, enforce_variable_access_ordered="no_check")
199201

200-
knl = register_optimization_preambles(knl, self.device)
201202
return knl
202203

203204

@@ -475,9 +476,11 @@ class P2PFromCSR(P2PBase):
475476
def default_name(self):
476477
return "p2p_from_csr"
477478

478-
def get_kernel(self,
479-
max_nsources_in_one_box: int, max_ntargets_in_one_box: int, *,
480-
work_items_per_group: int = 32, is_gpu: bool = False):
479+
def get_kernel(self, *,
480+
max_nsources_in_one_box: int = 32,
481+
max_ntargets_in_one_box: int = 32,
482+
work_items_per_group: int = 32,
483+
is_gpu: bool = False, **kwargs: Any) -> lp.TranslationUnit:
481484
loopy_insns, _result_names = self.get_loopy_insns_and_result_names()
482485
arguments = [
483486
*self.get_default_src_tgt_arguments(),
@@ -674,8 +677,10 @@ def get_kernel(self,
674677
"noutputs": len(self.target_kernels)},
675678
)
676679

677-
loopy_knl = lp.add_dtypes(
678-
loopy_knl, {"nsources": np.int32, "ntargets": np.int32})
680+
loopy_knl = lp.add_dtypes(loopy_knl, {
681+
"nsources": np.dtype(np.int32),
682+
"ntargets": np.dtype(np.int32),
683+
})
679684

680685
loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
681686
loopy_knl = lp.tag_inames(loopy_knl, "istrength*:unr")
@@ -687,19 +692,24 @@ def get_kernel(self,
687692

688693
return loopy_knl
689694

690-
def get_optimized_kernel(self,
691-
max_nsources_in_one_box: int,
692-
max_ntargets_in_one_box: int,
693-
strength_dtype: np.dtype[Any],
694-
source_dtype: np.dtype[Any],
695-
local_mem_size: int,
696-
is_gpu: bool):
695+
def get_optimized_kernel(self, *,
696+
max_nsources_in_one_box: int = 32,
697+
max_ntargets_in_one_box: int = 32,
698+
strength_dtype: np.dtype[Any] | None = None,
699+
source_dtype: np.dtype[Any] | None = None,
700+
local_mem_size: int = 32,
701+
is_gpu: bool = False, **kwargs) -> lp.TranslationUnit:
697702
if not is_gpu:
698-
knl = self.get_kernel(max_nsources_in_one_box,
699-
max_ntargets_in_one_box, is_gpu=is_gpu)
703+
knl = self.get_kernel(
704+
max_nsources_in_one_box=max_nsources_in_one_box,
705+
max_ntargets_in_one_box=max_ntargets_in_one_box,
706+
is_gpu=is_gpu)
700707
knl = lp.split_iname(knl, "itgt_box", 4, outer_tag="g.0")
701708
knl = self._allow_redundant_execution_of_knl_scaling(knl)
702709
else:
710+
assert strength_dtype is not None
711+
assert source_dtype is not None
712+
703713
dtype_size = np.dtype(strength_dtype).alignment
704714
work_items_per_group = min(256, max_ntargets_in_one_box)
705715
total_local_mem = max_nsources_in_one_box * \
@@ -708,8 +718,9 @@ def get_optimized_kernel(self,
708718
# can be scheduled at the same time for latency hiding
709719
nprefetch = (2 * total_local_mem - 1) // local_mem_size + 1
710720

711-
knl = self.get_kernel(max_nsources_in_one_box,
712-
max_ntargets_in_one_box,
721+
knl = self.get_kernel(
722+
max_nsources_in_one_box=max_nsources_in_one_box,
723+
max_ntargets_in_one_box=max_ntargets_in_one_box,
713724
work_items_per_group=work_items_per_group,
714725
is_gpu=is_gpu)
715726
knl = lp.tag_inames(knl, {"itgt_box": "g.0", "inner": "l.0"})
@@ -771,12 +782,8 @@ def get_optimized_kernel(self,
771782
knl = lp.add_inames_to_insn(knl,
772783
"inner", "id:init_* or id:*_scaling or id:src_box_insn_*")
773784
knl = lp.add_inames_to_insn(knl, "itgt_box", "id:*_scaling")
774-
# knl = lp.set_options(knl, write_code=True)
775-
776-
knl = lp.set_options(knl,
777-
enforce_variable_access_ordered="no_check")
778785

779-
knl = register_optimization_preambles(knl, self.device)
786+
knl = lp.set_options(knl, enforce_variable_access_ordered="no_check")
780787
return knl
781788

782789
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
@@ -786,8 +793,8 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
786793

787794
is_gpu = not is_cl_cpu(actx)
788795
if is_gpu:
789-
source_dtype = kwargs.get("sources")[0].dtype
790-
strength_dtype = kwargs.get("strength").dtype
796+
source_dtype = kwargs["sources"][0].dtype
797+
strength_dtype = kwargs["strength"].dtype
791798
else:
792799
# these are unused for not GPU and defeats the caching
793800
# set them to None to keep the caching across dtypes

sumpy/qbx.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
"""
2828

2929
import logging
30+
from typing import Any
3031

3132
import numpy as np
3233

@@ -195,14 +196,14 @@ def get_kernel(self):
195196
raise NotImplementedError
196197

197198
def get_optimized_kernel(self, *,
198-
is_cpu: bool,
199-
targets_is_obj_array: bool,
200-
sources_is_obj_array: bool,
201-
centers_is_obj_array: bool,
199+
is_cpu: bool = True,
200+
targets_is_obj_array: bool = False,
201+
sources_is_obj_array: bool = False,
202+
centers_is_obj_array: bool = False,
202203
# Used by pytential to override the name of the loop to be
203204
# parallelized. In the case of QBX, that's the loop over QBX
204205
# targets (not global targets).
205-
itgt_name: str = "itgt"):
206+
itgt_name: str = "itgt", **kwargs: Any) -> lp.TranslationUnit:
206207
# FIXME specialize/tune for GPU/CPU
207208
loopy_knl = self.get_kernel()
208209

sumpy/tools.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -429,11 +429,11 @@ def get_cache_key(self) -> tuple[Hashable, ...]:
429429
...
430430

431431
@abstractmethod
432-
def get_kernel(self, **kwargs: Any) -> lp.TranslationUnit:
432+
def get_kernel(self, **kwargs) -> lp.TranslationUnit:
433433
...
434434

435435
@abstractmethod
436-
def get_optimized_kernel(self, **kwargs: Any) -> lp.TranslationUnit:
436+
def get_optimized_kernel(self, **kwargs) -> lp.TranslationUnit:
437437
...
438438

439439
@memoize_method

sumpy/toys.py

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -509,8 +509,7 @@ def eval(self, actx: PyOpenCLArrayContext, targets: np.ndarray) -> np.ndarray:
509509
def __neg__(self) -> PotentialSource:
510510
return -1*self
511511

512-
def __add__(self, other: Number_ish | PotentialSource
513-
) -> PotentialSource:
512+
def __add__(self, other: Number_ish | PotentialSource) -> PotentialSource:
514513
if isinstance(other, Number | np.number):
515514
other = ConstantPotential(self.toy_ctx, other)
516515
elif not isinstance(other, PotentialSource):
@@ -520,18 +519,14 @@ def __add__(self, other: Number_ish | PotentialSource
520519

521520
__radd__ = __add__
522521

523-
def __sub__(self,
524-
other: Number_ish | PotentialSource) -> PotentialSource:
522+
def __sub__(self, other: Number_ish | PotentialSource) -> PotentialSource:
525523
return self.__add__(-other)
526524

527-
def __rsub__(
528-
self,
529-
other: Number | np.number | PotentialSource
530-
) -> PotentialSource:
525+
def __rsub__(self, # type: ignore[misc]
526+
other: Number_ish | PotentialSource) -> PotentialSource:
531527
return (-self).__add__(other)
532528

533-
def __mul__(self,
534-
other: Number_ish | PotentialSource) -> PotentialSource:
529+
def __mul__(self, other: Number_ish | PotentialSource) -> PotentialSource:
535530
if isinstance(other, Number | np.number):
536531
other = ConstantPotential(self.toy_ctx, other)
537532
elif not isinstance(other, PotentialSource):
@@ -722,9 +717,9 @@ class Sum(PotentialExpressionNode):
722717
"""
723718

724719
def eval(self, actx: PyOpenCLArrayContext, targets: np.ndarray) -> np.ndarray:
725-
result = 0
720+
result = np.zeros(targets.shape[1])
726721
for psource in self.psources:
727-
result = result + psource.eval(actx, targets)
722+
result += psource.eval(actx, targets)
728723

729724
return result
730725

@@ -735,9 +730,9 @@ class Product(PotentialExpressionNode):
735730
"""
736731

737732
def eval(self, actx: PyOpenCLArrayContext, targets: np.ndarray) -> np.ndarray:
738-
result = 1
733+
result = np.ones(targets.shape[1])
739734
for psource in self.psources:
740-
result = result * psource.eval(actx, targets)
735+
result *= psource.eval(actx, targets)
741736

742737
return result
743738

0 commit comments

Comments
 (0)