7777 from numpy .typing import DTypeLike
7878
7979 import pyopencl
80- from arraycontext import Array
80+ from arraycontext import Array , ArrayContext
8181
82- from sumpy .array_context import PyOpenCLArrayContext
8382 from sumpy .expansion .local import LocalExpansionBase
8483 from sumpy .expansion .multipole import MultipoleExpansionBase
8584
@@ -114,7 +113,7 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
114113 strength_usage : Sequence [int ] | None
115114
116115 def __init__ (self ,
117- array_context : PyOpenCLArrayContext ,
116+ array_context : ArrayContext ,
118117 multipole_expansion_factory : MultipoleExpansionFromOrderFactory ,
119118 local_expansion_factory : LocalExpansionFromOrderFactory ,
120119 target_kernels : Sequence [Kernel ],
@@ -134,7 +133,7 @@ def __init__(self,
134133 """
135134 super ().__init__ ()
136135
137- self ._setup_actx : PyOpenCLArrayContext = array_context
136+ self ._setup_actx : ArrayContext = array_context
138137
139138 self .multipole_expansion_factory = multipole_expansion_factory
140139 self .local_expansion_factory = local_expansion_factory
@@ -422,7 +421,7 @@ def order_to_size(order: int):
422421 return build_csr_level_starts (self .level_orders , order_to_size ,
423422 level_starts = self .m2l_translation_class_level_start_box_nrs ())
424423
425- def multipole_expansion_zeros (self , actx : PyOpenCLArrayContext ) -> Array :
424+ def multipole_expansion_zeros (self , actx : ArrayContext ) -> Array :
426425 """Return an expansions array (which must support addition)
427426 capable of holding one multipole or local expansion for every
428427 box in the tree.
@@ -441,7 +440,7 @@ def local_expansion_zeros(self, actx) -> Array:
441440 dtype = self .dtype )
442441
443442 def m2l_translation_classes_dependent_data_zeros (
444- self , actx : PyOpenCLArrayContext ):
443+ self , actx : ArrayContext ):
445444 data_level_starts = (
446445 self .m2l_translation_classes_dependent_data_level_starts ())
447446 level_start_box_nrs = (
@@ -497,7 +496,7 @@ def order_to_size(order):
497496 level_starts = self .tree_level_start_box_nrs )
498497
499498 def m2l_preproc_mpole_expansion_zeros (
500- self , actx : PyOpenCLArrayContext , template_ary ):
499+ self , actx : ArrayContext , template_ary ):
501500 level_starts = self .m2l_preproc_mpole_expansions_level_starts ()
502501
503502 result = []
@@ -522,7 +521,7 @@ def m2l_preproc_mpole_expansions_view(self, mpole_exps, level):
522521 m2l_work_array_level_starts = m2l_preproc_mpole_expansions_level_starts
523522
524523 def output_zeros (self ,
525- actx : PyOpenCLArrayContext
524+ actx : ArrayContext
526525 ) -> obj_array .ObjectArray1D [Array ]:
527526 """Return a potentials array (which must support addition) capable of
528527 holding a potential value for each target in the tree. Note that
@@ -587,7 +586,7 @@ def box_target_list_kwargs(self):
587586
588587 # }}}
589588
590- def run_opencl_fft (self , actx : PyOpenCLArrayContext ,
589+ def run_opencl_fft (self , actx : ArrayContext ,
591590 input_vec , inverse , wait_for ):
592591 app = self .tree_indep .opencl_fft_app (input_vec .shape , input_vec .dtype ,
593592 inverse )
@@ -601,7 +600,7 @@ def run_opencl_fft(self, actx: PyOpenCLArrayContext,
601600 return result
602601
603602 def form_multipoles (self ,
604- actx : PyOpenCLArrayContext ,
603+ actx : ArrayContext ,
605604 level_start_source_box_nrs , source_boxes ,
606605 src_weight_vecs ):
607606 mpoles = self .multipole_expansion_zeros (actx )
@@ -635,7 +634,7 @@ def form_multipoles(self,
635634 return mpoles
636635
637636 def coarsen_multipoles (self ,
638- actx : PyOpenCLArrayContext ,
637+ actx : ArrayContext ,
639638 level_start_source_parent_box_nrs ,
640639 source_parent_boxes ,
641640 mpoles ):
@@ -689,7 +688,7 @@ def coarsen_multipoles(self,
689688 return mpoles
690689
691690 def eval_direct (self ,
692- actx : PyOpenCLArrayContext ,
691+ actx : ArrayContext ,
693692 target_boxes , source_box_starts ,
694693 source_box_lists , src_weight_vecs ):
695694 pot = self .output_zeros (actx )
@@ -791,7 +790,7 @@ def _add_m2l_precompute_kwargs(self, kwargs_for_m2l,
791790 self .translation_classes_data .from_sep_siblings_translation_classes
792791
793792 def multipole_to_local (self ,
794- actx : PyOpenCLArrayContext ,
793+ actx : ArrayContext ,
795794 level_start_target_box_nrs ,
796795 target_boxes , src_box_starts , src_box_lists ,
797796 mpole_exps ):
@@ -915,7 +914,7 @@ def multipole_to_local(self,
915914 return local_exps
916915
917916 def eval_multipoles (self ,
918- actx : PyOpenCLArrayContext ,
917+ actx : ArrayContext ,
919918 target_boxes_by_source_level , source_boxes_by_level , mpole_exps ):
920919 pot = self .output_zeros (actx )
921920
@@ -956,7 +955,7 @@ def eval_multipoles(self,
956955 return pot
957956
958957 def form_locals (self ,
959- actx : PyOpenCLArrayContext ,
958+ actx : ArrayContext ,
960959 level_start_target_or_target_parent_box_nrs ,
961960 target_or_target_parent_boxes , starts , lists , src_weight_vecs ):
962961 local_exps = self .local_expansion_zeros (actx )
@@ -997,7 +996,7 @@ def form_locals(self,
997996 return local_exps
998997
999998 def refine_locals (self ,
1000- actx : PyOpenCLArrayContext ,
999+ actx : ArrayContext ,
10011000 level_start_target_or_target_parent_box_nrs ,
10021001 target_or_target_parent_boxes ,
10031002 local_exps ):
@@ -1040,7 +1039,7 @@ def refine_locals(self,
10401039 return local_exps
10411040
10421041 def eval_locals (self ,
1043- actx : PyOpenCLArrayContext ,
1042+ actx : ArrayContext ,
10441043 level_start_target_box_nrs , target_boxes , local_exps ):
10451044 pot = self .output_zeros (actx )
10461045 level_start_target_box_nrs = actx .to_numpy (level_start_target_box_nrs )
@@ -1077,7 +1076,7 @@ def eval_locals(self,
10771076
10781077 return pot
10791078
1080- def finalize_potentials (self , actx : PyOpenCLArrayContext , potentials ):
1079+ def finalize_potentials (self , actx : ArrayContext , potentials ):
10811080 return potentials
10821081
10831082# }}}
0 commit comments