Skip to content

Commit add1974

Browse files
committed
Placate pylint and mypy
1 parent 5cfb872 commit add1974

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

pytential/qbx/distributed.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from pytential.qbx import QBXLayerPotentialSource
2424
from arraycontext import PyOpenCLArrayContext, unflatten
25-
from typing import Optional, Any
25+
from typing import Any, Dict
2626
from dataclasses import dataclass
2727
import numpy as np
2828
import pyopencl as cl
@@ -135,7 +135,7 @@ def __call__(self, queue, tree, **kwargs):
135135

136136

137137
def broadcast_global_geometry_data(
138-
comm, actx: PyOpenCLArrayContext, traversal_builder, global_geometry_data):
138+
comm, actx, traversal_builder, global_geometry_data):
139139
"""Broadcasts useful fields of global geometry data from the root rank to the
140140
worker ranks, so that each rank can form local geometry data independently.
141141
@@ -212,7 +212,7 @@ def broadcast_global_geometry_data(
212212

213213

214214
def compute_local_geometry_data(
215-
actx: PyOpenCLArrayContext, comm, global_geometry_data, boxes_time,
215+
actx, comm, global_geometry_data, boxes_time,
216216
traversal_builder):
217217
"""Compute the local geometry data of the current rank from the global geometry
218218
data.
@@ -572,10 +572,7 @@ def distribute_geo_data(comm, actx, insn, bound_expr, evaluate,
572572

573573

574574
class DistributedQBXLayerPotentialSource(QBXLayerPotentialSource):
575-
def __init__(self, comm, cl_context, *args,
576-
_use_target_specific_qbx: Optional[bool] = None,
577-
fmm_backend: str = "fmmlib",
578-
**kwargs):
575+
def __init__(self, comm, cl_context, *args, **kwargs):
579576
"""
580577
:arg comm: MPI communicator.
581578
:arg cl_context: This argument is necessary because although the root rank
@@ -584,6 +581,8 @@ def __init__(self, comm, cl_context, *args,
584581
585582
`*args` and `**kwargs` will be forwarded to
586583
`QBXLayerPotentialSource.__init__` on the root rank.
584+
585+
Currently, `fmm_backend` has to be set to `"fmmlib"`.
587586
"""
588587
self.comm = comm
589588
self._cl_context = cl_context
@@ -600,17 +599,14 @@ def __init__(self, comm, cl_context, *args,
600599
"distributed implementation")
601600

602601
# Only fmmlib is supported
603-
assert fmm_backend == "fmmlib"
602+
assert kwargs["fmm_backend"] == "fmmlib"
604603

605604
if self.comm.Get_rank() == 0:
606-
super().__init__(
607-
*args,
608-
_use_target_specific_qbx=_use_target_specific_qbx,
609-
fmm_backend=fmm_backend,
610-
**kwargs)
605+
super().__init__(*args, **kwargs)
611606
else:
612-
self._use_target_specific_qbx = _use_target_specific_qbx
613-
self.fmm_backend = fmm_backend
607+
self.fmm_backend = "fmmlib"
608+
self._use_target_specific_qbx = kwargs.get(
609+
"_use_target_specific_qbx", None)
614610
self.qbx_order = kwargs.get("qbx_order", None)
615611
self.fmm_level_to_order = kwargs.get("fmm_level_to_order", None)
616612
self.expansion_factory = kwargs.get("expansion_factory", None)
@@ -727,6 +723,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
727723

728724
user_source_ids = None
729725
if self.comm.Get_rank() == 0:
726+
assert global_geo_data_device is not None
730727
user_source_ids = global_geo_data_device.tree().user_source_ids
731728

732729
kernel_extra_kwargs, source_extra_kwargs = (
@@ -767,6 +764,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
767764
self._use_target_specific_qbx)
768765

769766
if self.comm.Get_rank() == 0:
767+
assert global_geo_data_device is not None
770768
from pytential.qbx.geometry import target_state
771769
if actx.to_numpy(actx.np.any(
772770
actx.thaw(global_geo_data_device.user_target_to_center())
@@ -784,11 +782,12 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
784782
# }}}
785783

786784
# Execute global QBX.
787-
timing_data = {}
785+
timing_data: Dict[str, Any] = {}
788786
all_potentials_on_every_target = drive_dfmm(
789787
self.comm, flat_strengths, wrangler, timing_data)
790788

791789
if self.comm.Get_rank() == 0:
790+
assert global_geo_data_device is not None
792791
results = []
793792

794793
for o in insn.outputs:
@@ -830,6 +829,7 @@ def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
830829

831830
# {{{ Distribute source weights
832831

832+
template_ary = None
833833
if current_rank == 0:
834834
template_ary = src_weight_vecs[0]
835835

0 commit comments

Comments
 (0)