2222
2323from pytential .qbx import QBXLayerPotentialSource
2424from arraycontext import PyOpenCLArrayContext , unflatten
25- from typing import Optional , Any
25+ from typing import Any , Dict
2626from dataclasses import dataclass
2727import numpy as np
2828import pyopencl as cl
@@ -135,7 +135,7 @@ def __call__(self, queue, tree, **kwargs):
135135
136136
137137def broadcast_global_geometry_data (
138- comm , actx : PyOpenCLArrayContext , traversal_builder , global_geometry_data ):
138+ comm , actx , traversal_builder , global_geometry_data ):
139139 """Broadcasts useful fields of global geometry data from the root rank to the
140140 worker ranks, so that each rank can form local geometry data independently.
141141
@@ -212,7 +212,7 @@ def broadcast_global_geometry_data(
212212
213213
214214def compute_local_geometry_data (
215- actx : PyOpenCLArrayContext , comm , global_geometry_data , boxes_time ,
215+ actx , comm , global_geometry_data , boxes_time ,
216216 traversal_builder ):
217217 """Compute the local geometry data of the current rank from the global geometry
218218 data.
@@ -572,10 +572,7 @@ def distribute_geo_data(comm, actx, insn, bound_expr, evaluate,
572572
573573
574574class DistributedQBXLayerPotentialSource (QBXLayerPotentialSource ):
575- def __init__ (self , comm , cl_context , * args ,
576- _use_target_specific_qbx : Optional [bool ] = None ,
577- fmm_backend : str = "fmmlib" ,
578- ** kwargs ):
575+ def __init__ (self , comm , cl_context , * args , ** kwargs ):
579576 """
580577 :arg comm: MPI communicator.
581578 :arg cl_context: This argument is necessary because although the root rank
@@ -584,6 +581,8 @@ def __init__(self, comm, cl_context, *args,
584581
585582 `*args` and `**kwargs` will be forwarded to
586583 `QBXLayerPotentialSource.__init__` on the root rank.
584+
585+ Currently, `fmm_backend` has to be set to `"fmmlib"`.
587586 """
588587 self .comm = comm
589588 self ._cl_context = cl_context
@@ -600,17 +599,14 @@ def __init__(self, comm, cl_context, *args,
600599 "distributed implementation" )
601600
602601 # Only fmmlib is supported
603- assert fmm_backend == "fmmlib"
602+ assert kwargs [ " fmm_backend" ] == "fmmlib"
604603
605604 if self .comm .Get_rank () == 0 :
606- super ().__init__ (
607- * args ,
608- _use_target_specific_qbx = _use_target_specific_qbx ,
609- fmm_backend = fmm_backend ,
610- ** kwargs )
605+ super ().__init__ (* args , ** kwargs )
611606 else :
612- self ._use_target_specific_qbx = _use_target_specific_qbx
613- self .fmm_backend = fmm_backend
607+ self .fmm_backend = "fmmlib"
608+ self ._use_target_specific_qbx = kwargs .get (
609+ "_use_target_specific_qbx" , None )
614610 self .qbx_order = kwargs .get ("qbx_order" , None )
615611 self .fmm_level_to_order = kwargs .get ("fmm_level_to_order" , None )
616612 self .expansion_factory = kwargs .get ("expansion_factory" , None )
@@ -727,6 +723,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
727723
728724 user_source_ids = None
729725 if self .comm .Get_rank () == 0 :
726+ assert global_geo_data_device is not None
730727 user_source_ids = global_geo_data_device .tree ().user_source_ids
731728
732729 kernel_extra_kwargs , source_extra_kwargs = (
@@ -767,6 +764,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
767764 self ._use_target_specific_qbx )
768765
769766 if self .comm .Get_rank () == 0 :
767+ assert global_geo_data_device is not None
770768 from pytential .qbx .geometry import target_state
771769 if actx .to_numpy (actx .np .any (
772770 actx .thaw (global_geo_data_device .user_target_to_center ())
@@ -784,11 +782,12 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
784782 # }}}
785783
786784 # Execute global QBX.
787- timing_data = {}
785+ timing_data : Dict [ str , Any ] = {}
788786 all_potentials_on_every_target = drive_dfmm (
789787 self .comm , flat_strengths , wrangler , timing_data )
790788
791789 if self .comm .Get_rank () == 0 :
790+ assert global_geo_data_device is not None
792791 results = []
793792
794793 for o in insn .outputs :
@@ -830,6 +829,7 @@ def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
830829
831830 # {{{ Distribute source weights
832831
832+ template_ary = None
833833 if current_rank == 0 :
834834 template_ary = src_weight_vecs [0 ]
835835
0 commit comments