Skip to content

Commit b8ed9e1

Browse files
committed
docs and memoize
1 parent fb3e69f commit b8ed9e1

File tree

1 file changed

+70
-34
lines changed

1 file changed

+70
-34
lines changed

pytential/source.py

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@
2626
from arraycontext import flatten, unflatten
2727
from meshmode.dof_array import DOFArray
2828
from meshmode.discretization import Discretization
29+
from arraycontext import ArrayContext
2930

3031
from sumpy.fmm import (SumpyTimingFuture,
3132
SumpyTreeIndependentDataForWrangler, SumpyExpansionWrangler)
3233
from sumpy.expansion import DefaultExpansionFactory
3334

3435
from functools import partial
3536
from collections import defaultdict
37+
from typing import Optional, Mapping, Union, Callable
3638

3739

3840
__doc__ = """
@@ -107,22 +109,43 @@ def evaluate_kernel_arguments(actx, evaluate, kernel_arguments, flat=True):
107109

108110
class PointPotentialSource(_SumpyP2PMixin, PotentialSource):
109111
"""
110-
.. attribute:: nodes
112+
.. method:: nodes
111113
112114
An :class:`pyopencl.array.Array` of shape ``[ambient_dim, ndofs]``.
113115
114116
.. attribute:: ndofs
115117
116-
.. attribute:: fmm_order
117-
118118
.. automethod:: cost_model_compute_potential_insn
119119
.. automethod:: exec_compute_potential_insn
120120
"""
121121

122-
def __init__(self, nodes, *, fmm_order=False, fmm_level_to_order=None,
123-
expansion_factory=default_expansion_factory,
124-
tree_build_kwargs=None,
125-
trav_build_kwargs=None):
122+
def __init__(self, nodes, *,
123+
fmm_order: Optional[int] = False,
124+
fmm_level_to_order: Optional[Union[bool, Callable[..., int]]] = None,
125+
expansion_factory: Optional[DefaultExpansionFactory] \
126+
= default_expansion_factory,
127+
tree_build_kwargs: Optional[Mapping] = None,
128+
trav_build_kwargs: Optional[Mapping] = None,
129+
setup_actx: Optional[ArrayContext] = None
130+
):
131+
"""
132+
:arg nodes: The point potential source given as a
133+
:class:`pyopencl.array.Array`
134+
:arg fmm_order: The order of the FMM for all levels if *fmm_order* is not
135+
*False*. Mutually exclusive with argument *fmm_level_to_order*.
136+
If both arguments are not given a direct point-to-point calculation
137+
is used.
138+
:arg fmm_level_to_order: An optional callable that returns the FMM order
139+
to use for a given level. Mutually exclusive with *fmm_order* argument.
140+
:arg expansion_factory: An expansion factory to get the expansion objects
141+
when an FMM is used.
142+
:arg tree_build_kwargs: Keyword arguments to be passed when building the
143+
tree for an FMM.
144+
:arg trav_build_kwargs: Keyword arguments to be passed when building a
145+
traversal for an FMM.
146+
:arg setup_actx: An array context to be used when building a tree
147+
for an FMM.
148+
"""
126149

127150
if fmm_order is not False and fmm_level_to_order is not None:
128151
raise TypeError("may not specify both fmm_order and fmm_level_to_order")
@@ -137,6 +160,7 @@ def fmm_level_to_order(kernel, kernel_args, tree, level): # noqa pylint:disable
137160
self.expansion_factory = expansion_factory
138161
self.tree_build_kwargs = tree_build_kwargs if tree_build_kwargs else {}
139162
self.trav_build_kwargs = trav_build_kwargs if trav_build_kwargs else {}
163+
self._setup_actx = setup_actx
140164
self._nodes = nodes
141165

142166
@property
@@ -159,10 +183,13 @@ def ndofs(self):
159183
for coord_ary in self._nodes:
160184
return coord_ary.shape[0]
161185

162-
def copy(self, nodes=None, fmm_order=None, fmm_level_to_order=None,
163-
expansion_factory=None, tree_build_kwargs=None, trav_build_kwargs=None):
186+
def copy(self, *, nodes=None, fmm_order=None, fmm_level_to_order=None,
187+
expansion_factory=None, tree_build_kwargs=None, trav_build_kwargs=None,
188+
setup_actx=None):
164189
if nodes is None:
165190
nodes = self._nodes
191+
if setup_actx is None:
192+
setup_actx = self._setup_actx
166193
if fmm_level_to_order is None and fmm_order is None:
167194
fmm_level_to_order = self.fmm_level_to_order
168195
if expansion_factory is None:
@@ -179,6 +206,7 @@ def copy(self, nodes=None, fmm_order=None, fmm_level_to_order=None,
179206
expansion_factory=expansion_factory,
180207
tree_build_kwargs=tree_build_kwargs,
181208
trav_build_kwargs=trav_build_kwargs,
209+
setup_actx=setup_actx,
182210
)
183211

184212
@property
@@ -194,6 +222,7 @@ def ambient_dim(self):
194222

195223
def op_group_features(self, expr):
196224
from pytential.utils import sort_arrays_together
225+
from sumpy.kernel import TargetTransformationRemover
197226
# since IntGs with the same source kernels and densities calculations
198227
# for P2E and E2E are the same and only differs in E2P depending on the
199228
# target kernel, we group all IntGs with same source kernels and densities.
@@ -212,14 +241,33 @@ def cost_model_compute_potential_insn(self, actx, insn, bound_expr,
212241
raise NotImplementedError
213242

214243
@memoize_method
215-
def _get_exec_insn_func(self, actx, source_kernels, target_kernels):
244+
def _get_tree(self, target_discr):
245+
"""Builds a tree for targets given by *target_discr* and caches the
246+
result. Needed only when an FMM is used.
247+
"""
248+
from boxtree import TreeBuilder
249+
from boxtree.traversal import FMMTraversalBuilder
250+
251+
actx = self._setup_actx
252+
sources = self._nodes
253+
targets = flatten(target_discr.nodes(), actx, leaf_class=DOFArray)
254+
tree_build = TreeBuilder(actx.context)
255+
trav_build = FMMTraversalBuilder(actx.context,
256+
**self.trav_build_kwargs)
257+
tree, _ = tree_build(actx.queue, sources, targets=targets,
258+
**self.tree_build_kwargs)
259+
trav, _ = trav_build(actx.queue, tree)
260+
return tree, trav
216261

262+
@memoize_method
263+
def _get_exec_insn_func(self, source_kernels, target_kernels, target_discr):
217264
if self.fmm_level_to_order is False:
218-
p2p = self.get_p2p(actx, source_kernels=source_kernels,
265+
sources = self._nodes
266+
targets = flatten(target_discr.nodes(), self._setup_actx, leaf_class=DOFArray)
267+
def exec_insn(actx, strengths, kernel_args, dtype, return_timing_data):
268+
p2p = self.get_p2p(actx, source_kernels=source_kernels,
219269
target_kernels=target_kernels)
220270

221-
def exec_insn(sources, targets, strengths, kernel_args,
222-
dtype, return_timing_data):
223271
evt, output = p2p(actx.queue,
224272
targets=targets,
225273
sources=sources,
@@ -232,14 +280,8 @@ def exec_insn(sources, targets, strengths, kernel_args,
232280
timing_data = None
233281
return timing_data, output
234282
else:
235-
from boxtree import TreeBuilder
236-
from boxtree.traversal import FMMTraversalBuilder
237283
from boxtree.fmm import drive_fmm
238284

239-
tree_build = TreeBuilder(actx.context)
240-
trav_build = FMMTraversalBuilder(actx.context,
241-
**self.trav_build_kwargs)
242-
243285
kernel = target_kernels[0].get_base_kernel()
244286
local_expansion_factory = \
245287
self.expansion_factory.get_local_expansion_class(kernel)
@@ -248,11 +290,9 @@ def exec_insn(sources, targets, strengths, kernel_args,
248290
self.expansion_factory.get_multipole_expansion_class(kernel)
249291
mpole_expansion_factory = partial(mpole_expansion_factory, kernel)
250292

251-
def exec_insn(sources, targets, strengths, kernel_args,
252-
dtype, return_timing_data):
253-
tree, _ = tree_build(actx.queue, sources, targets=targets,
254-
**self.tree_build_kwargs)
255-
trav, _ = trav_build(actx.queue, tree)
293+
tree, trav = self._get_tree(target_discr)
294+
295+
def exec_insn(actx, strengths, kernel_args, dtype, return_timing_data):
256296
tree_indep = SumpyTreeIndependentDataForWrangler(
257297
actx.context,
258298
mpole_expansion_factory,
@@ -263,7 +303,6 @@ def exec_insn(sources, targets, strengths, kernel_args,
263303
wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
264304
fmm_level_to_order=self.fmm_level_to_order,
265305
kernel_extra_kwargs=kernel_args)
266-
267306
timing_data = {} if return_timing_data else None
268307
output = drive_fmm(wrangler, strengths, timing_data=timing_data)
269308
return timing_data, output
@@ -283,12 +322,6 @@ def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
283322
else:
284323
dtype = self.real_dtype
285324

286-
exec_insn = self._get_exec_insn_func(
287-
actx=actx,
288-
source_kernels=insn.source_kernels,
289-
target_kernels=insn.target_kernels,
290-
)
291-
292325
outputs_grouped_by_target = defaultdict(list)
293326
for o in insn.outputs:
294327
outputs_grouped_by_target[o.target_name].append(o)
@@ -299,11 +332,14 @@ def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
299332
target_discr = bound_expr.places.get_discretization(
300333
target_name.geometry, target_name.discr_stage)
301334

302-
sources = self._nodes
303-
targets = flatten(target_discr.nodes(), actx, leaf_class=DOFArray)
335+
exec_insn = self._get_exec_insn_func(
336+
source_kernels=insn.source_kernels,
337+
target_kernels=insn.target_kernels,
338+
target_discr=target_discr,
339+
)
304340

305341
timing_data, output_for_each_kernel = \
306-
exec_insn(sources, targets, strengths, kernel_args,
342+
exec_insn(actx, strengths, kernel_args,
307343
dtype, return_timing_data)
308344
timing_data_arr.append(timing_data)
309345

0 commit comments

Comments
 (0)