Skip to content

Commit 039119e

Browse files
committed
Use ArrayContext+assert instead of PyOpenCLActx in annotations
1 parent 680aef7 commit 039119e

File tree

11 files changed

+102
-73
lines changed

11 files changed

+102
-73
lines changed

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
"CallInstruction": "class:loopy.kernel.instruction.CallInstruction",
5656
# arraycontext
5757
"Array": "obj:arraycontext.Array",
58+
"ArrayContext": "class:arraycontext.ArrayContext",
5859
# boxtree
5960
"FMMTraversalInfo": "class:boxtree.traversal.FMMTraversalInfo",
6061
# sumpy

sumpy/array_context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from numpy.typing import DTypeLike
4242

43+
from arraycontext import ArrayContext
4344
from loopy import TranslationUnit
4445
from loopy.codegen import PreambleInfo
4546
from pytools.tag import ToTagSetConvertible
@@ -112,7 +113,10 @@ def transform_loopy_program(self, t_unit: TranslationUnit):
112113
return t_unit
113114

114115

115-
def is_cl_cpu(actx: PyOpenCLArrayContext) -> bool:
116+
def is_cl_cpu(actx: ArrayContext) -> bool:
117+
if not isinstance(actx, PyOpenCLArrayContext):
118+
return False
119+
116120
import pyopencl as cl
117121
return all(dev.type & cl.device_type.CPU for dev in actx.context.devices)
118122

sumpy/distributed.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@
3333

3434

3535
if TYPE_CHECKING:
36-
from sumpy.array_context import PyOpenCLArrayContext
36+
from arraycontext import ArrayContext
3737

3838

3939
class DistributedSumpyExpansionWrangler(
4040
DistributedExpansionWranglerMixin, SumpyExpansionWrangler):
4141
def __init__(
42-
self, actx: PyOpenCLArrayContext,
42+
self, actx: ArrayContext,
4343
comm, tree_indep, local_traversal, global_traversal,
4444
dtype, fmm_level_to_order, communicate_mpoles_via_allreduce=False,
4545
**kwargs):
@@ -53,7 +53,7 @@ def __init__(
5353
self.communicate_mpoles_via_allreduce = communicate_mpoles_via_allreduce
5454

5555
def distribute_source_weights(self,
56-
actx: PyOpenCLArrayContext, src_weight_vecs, src_idx_all_ranks):
56+
actx: ArrayContext, src_weight_vecs, src_idx_all_ranks):
5757
src_weight_vecs_host = [
5858
actx.to_numpy(src_weight) for src_weight in src_weight_vecs
5959
]
@@ -68,7 +68,7 @@ def distribute_source_weights(self,
6868
return local_src_weight_vecs_device
6969

7070
def gather_potential_results(self,
71-
actx: PyOpenCLArrayContext, potentials, tgt_idx_all_ranks):
71+
actx: ArrayContext, potentials, tgt_idx_all_ranks):
7272
potentials_host_vec = [
7373
actx.to_numpy(potentials_dev) for potentials_dev in potentials
7474
]
@@ -109,7 +109,7 @@ def reorder(x):
109109
return None
110110

111111
def communicate_mpoles(self,
112-
actx: PyOpenCLArrayContext, mpole_exps, return_stats=False):
112+
actx: ArrayContext, mpole_exps, return_stats=False):
113113
mpole_exps_host = actx.to_numpy(mpole_exps)
114114
stats = super().communicate_mpoles(actx, mpole_exps_host, return_stats)
115115
mpole_exps[:] = mpole_exps_host

sumpy/e2e.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import logging
2727
from abc import ABC, abstractmethod
28+
from typing import TYPE_CHECKING
2829

2930
import numpy as np
3031
from typing_extensions import override
@@ -34,10 +35,14 @@
3435
from pytools import memoize_method
3536

3637
import sumpy.symbolic as sym
37-
from sumpy.array_context import PyOpenCLArrayContext, make_loopy_program
38+
from sumpy.array_context import make_loopy_program
3839
from sumpy.tools import KernelCacheMixin, to_complex_dtype
3940

4041

42+
if TYPE_CHECKING:
43+
from arraycontext import ArrayContext
44+
45+
4146
logger = logging.getLogger(__name__)
4247

4348

@@ -267,7 +272,7 @@ def get_optimized_kernel(self):
267272

268273
return knl
269274

270-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
275+
def __call__(self, actx: ArrayContext, **kwargs):
271276
"""
272277
:arg src_expansions:
273278
:arg src_box_starts:
@@ -511,7 +516,7 @@ def get_optimized_kernel(self, result_dtype):
511516

512517
return knl
513518

514-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
519+
def __call__(self, actx: ArrayContext, **kwargs):
515520
"""
516521
:arg src_expansions:
517522
:arg src_box_starts:
@@ -624,7 +629,7 @@ def get_optimized_kernel(self, result_dtype):
624629

625630
return knl
626631

627-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
632+
def __call__(self, actx: ArrayContext, **kwargs):
628633
"""
629634
:arg src_rscale:
630635
:arg translation_classes_level_start:
@@ -729,7 +734,7 @@ def get_optimized_kernel(self, result_dtype):
729734
knl = optimization(knl)
730735
return knl
731736

732-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
737+
def __call__(self, actx: ArrayContext, **kwargs):
733738
"""
734739
:arg src_expansions
735740
:arg preprocessed_src_expansions
@@ -830,7 +835,7 @@ def get_optimized_kernel(self, result_dtype):
830835
knl = lp.add_inames_for_unused_hw_axes(knl)
831836
return knl
832837

833-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
838+
def __call__(self, actx: ArrayContext, **kwargs):
834839
"""
835840
:arg tgt_expansions
836841
:arg tgt_expansions_before_postprocessing
@@ -943,7 +948,7 @@ def get_kernel(self):
943948

944949
return loopy_knl
945950

946-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
951+
def __call__(self, actx: ArrayContext, **kwargs):
947952
"""
948953
:arg src_expansions:
949954
:arg src_box_starts:
@@ -1050,7 +1055,7 @@ def get_kernel(self):
10501055

10511056
return loopy_knl
10521057

1053-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
1058+
def __call__(self, actx: ArrayContext, **kwargs):
10541059
"""
10551060
:arg src_expansions:
10561061
:arg src_box_starts:

sumpy/e2p.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,22 @@
2424
"""
2525

2626
from abc import ABC, abstractmethod
27+
from typing import TYPE_CHECKING
2728

2829
import numpy as np
2930

3031
import loopy as lp
3132
import pytools.obj_array as obj_array
3233
from loopy.version import MOST_RECENT_LANGUAGE_VERSION # noqa: F401
3334

34-
from sumpy.array_context import PyOpenCLArrayContext, make_loopy_program
35+
from sumpy.array_context import make_loopy_program
3536
from sumpy.tools import KernelCacheMixin, gather_loopy_arguments
3637

3738

39+
if TYPE_CHECKING:
40+
from arraycontext import ArrayContext
41+
42+
3843
__doc__ = """
3944
4045
Expansion-to-particle
@@ -203,7 +208,7 @@ def get_optimized_kernel(self):
203208

204209
return knl
205210

206-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
211+
def __call__(self, actx: ArrayContext, **kwargs):
207212
"""
208213
:arg expansions:
209214
:arg target_boxes:
@@ -331,7 +336,7 @@ def get_optimized_kernel(self):
331336

332337
return knl
333338

334-
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
339+
def __call__(self, actx: ArrayContext, **kwargs):
335340
centers = kwargs.pop("centers")
336341
# "1" may be passed for rscale, which won't have its type
337342
# meaningfully inferred. Make the type of rscale explicit.

sumpy/fmm.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@
7777
from numpy.typing import DTypeLike
7878

7979
import pyopencl
80-
from arraycontext import Array
80+
from arraycontext import Array, ArrayContext
8181

82-
from sumpy.array_context import PyOpenCLArrayContext
8382
from sumpy.expansion.local import LocalExpansionBase
8483
from sumpy.expansion.multipole import MultipoleExpansionBase
8584

@@ -114,7 +113,7 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
114113
strength_usage: Sequence[int] | None
115114

116115
def __init__(self,
117-
array_context: PyOpenCLArrayContext,
116+
array_context: ArrayContext,
118117
multipole_expansion_factory: MultipoleExpansionFromOrderFactory,
119118
local_expansion_factory: LocalExpansionFromOrderFactory,
120119
target_kernels: Sequence[Kernel],
@@ -134,7 +133,7 @@ def __init__(self,
134133
"""
135134
super().__init__()
136135

137-
self._setup_actx: PyOpenCLArrayContext = array_context
136+
self._setup_actx: ArrayContext = array_context
138137

139138
self.multipole_expansion_factory = multipole_expansion_factory
140139
self.local_expansion_factory = local_expansion_factory
@@ -422,7 +421,7 @@ def order_to_size(order: int):
422421
return build_csr_level_starts(self.level_orders, order_to_size,
423422
level_starts=self.m2l_translation_class_level_start_box_nrs())
424423

425-
def multipole_expansion_zeros(self, actx: PyOpenCLArrayContext) -> Array:
424+
def multipole_expansion_zeros(self, actx: ArrayContext) -> Array:
426425
"""Return an expansions array (which must support addition)
427426
capable of holding one multipole or local expansion for every
428427
box in the tree.
@@ -441,7 +440,7 @@ def local_expansion_zeros(self, actx) -> Array:
441440
dtype=self.dtype)
442441

443442
def m2l_translation_classes_dependent_data_zeros(
444-
self, actx: PyOpenCLArrayContext):
443+
self, actx: ArrayContext):
445444
data_level_starts = (
446445
self.m2l_translation_classes_dependent_data_level_starts())
447446
level_start_box_nrs = (
@@ -497,7 +496,7 @@ def order_to_size(order):
497496
level_starts=self.tree_level_start_box_nrs)
498497

499498
def m2l_preproc_mpole_expansion_zeros(
500-
self, actx: PyOpenCLArrayContext, template_ary):
499+
self, actx: ArrayContext, template_ary):
501500
level_starts = self.m2l_preproc_mpole_expansions_level_starts()
502501

503502
result = []
@@ -522,7 +521,7 @@ def m2l_preproc_mpole_expansions_view(self, mpole_exps, level):
522521
m2l_work_array_level_starts = m2l_preproc_mpole_expansions_level_starts
523522

524523
def output_zeros(self,
525-
actx: PyOpenCLArrayContext
524+
actx: ArrayContext
526525
) -> obj_array.ObjectArray1D[Array]:
527526
"""Return a potentials array (which must support addition) capable of
528527
holding a potential value for each target in the tree. Note that
@@ -587,7 +586,7 @@ def box_target_list_kwargs(self):
587586

588587
# }}}
589588

590-
def run_opencl_fft(self, actx: PyOpenCLArrayContext,
589+
def run_opencl_fft(self, actx: ArrayContext,
591590
input_vec, inverse, wait_for):
592591
app = self.tree_indep.opencl_fft_app(input_vec.shape, input_vec.dtype,
593592
inverse)
@@ -601,7 +600,7 @@ def run_opencl_fft(self, actx: PyOpenCLArrayContext,
601600
return result
602601

603602
def form_multipoles(self,
604-
actx: PyOpenCLArrayContext,
603+
actx: ArrayContext,
605604
level_start_source_box_nrs, source_boxes,
606605
src_weight_vecs):
607606
mpoles = self.multipole_expansion_zeros(actx)
@@ -635,7 +634,7 @@ def form_multipoles(self,
635634
return mpoles
636635

637636
def coarsen_multipoles(self,
638-
actx: PyOpenCLArrayContext,
637+
actx: ArrayContext,
639638
level_start_source_parent_box_nrs,
640639
source_parent_boxes,
641640
mpoles):
@@ -689,7 +688,7 @@ def coarsen_multipoles(self,
689688
return mpoles
690689

691690
def eval_direct(self,
692-
actx: PyOpenCLArrayContext,
691+
actx: ArrayContext,
693692
target_boxes, source_box_starts,
694693
source_box_lists, src_weight_vecs):
695694
pot = self.output_zeros(actx)
@@ -791,7 +790,7 @@ def _add_m2l_precompute_kwargs(self, kwargs_for_m2l,
791790
self.translation_classes_data.from_sep_siblings_translation_classes
792791

793792
def multipole_to_local(self,
794-
actx: PyOpenCLArrayContext,
793+
actx: ArrayContext,
795794
level_start_target_box_nrs,
796795
target_boxes, src_box_starts, src_box_lists,
797796
mpole_exps):
@@ -915,7 +914,7 @@ def multipole_to_local(self,
915914
return local_exps
916915

917916
def eval_multipoles(self,
918-
actx: PyOpenCLArrayContext,
917+
actx: ArrayContext,
919918
target_boxes_by_source_level, source_boxes_by_level, mpole_exps):
920919
pot = self.output_zeros(actx)
921920

@@ -956,7 +955,7 @@ def eval_multipoles(self,
956955
return pot
957956

958957
def form_locals(self,
959-
actx: PyOpenCLArrayContext,
958+
actx: ArrayContext,
960959
level_start_target_or_target_parent_box_nrs,
961960
target_or_target_parent_boxes, starts, lists, src_weight_vecs):
962961
local_exps = self.local_expansion_zeros(actx)
@@ -997,7 +996,7 @@ def form_locals(self,
997996
return local_exps
998997

999998
def refine_locals(self,
1000-
actx: PyOpenCLArrayContext,
999+
actx: ArrayContext,
10011000
level_start_target_or_target_parent_box_nrs,
10021001
target_or_target_parent_boxes,
10031002
local_exps):
@@ -1040,7 +1039,7 @@ def refine_locals(self,
10401039
return local_exps
10411040

10421041
def eval_locals(self,
1043-
actx: PyOpenCLArrayContext,
1042+
actx: ArrayContext,
10441043
level_start_target_box_nrs, target_boxes, local_exps):
10451044
pot = self.output_zeros(actx)
10461045
level_start_target_box_nrs = actx.to_numpy(level_start_target_box_nrs)
@@ -1077,7 +1076,7 @@ def eval_locals(self,
10771076

10781077
return pot
10791078

1080-
def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials):
1079+
def finalize_potentials(self, actx: ArrayContext, potentials):
10811080
return potentials
10821081

10831082
# }}}

0 commit comments

Comments
 (0)