2626from arraycontext import flatten , unflatten
2727from meshmode .dof_array import DOFArray
2828from meshmode .discretization import Discretization
29+ from arraycontext import ArrayContext
2930
3031from sumpy .fmm import (SumpyTimingFuture ,
3132 SumpyTreeIndependentDataForWrangler , SumpyExpansionWrangler )
3233from sumpy .expansion import DefaultExpansionFactory
3334
3435from functools import partial
3536from collections import defaultdict
37+ from typing import Optional , Mapping , Union , Callable
3638
3739
3840__doc__ = """
@@ -107,22 +109,43 @@ def evaluate_kernel_arguments(actx, evaluate, kernel_arguments, flat=True):
107109
108110class PointPotentialSource (_SumpyP2PMixin , PotentialSource ):
109111 """
110- .. attribute :: nodes
112+ .. method :: nodes
111113
112114 An :class:`pyopencl.array.Array` of shape ``[ambient_dim, ndofs]``.
113115
114116 .. attribute:: ndofs
115117
116- .. attribute:: fmm_order
117-
118118 .. automethod:: cost_model_compute_potential_insn
119119 .. automethod:: exec_compute_potential_insn
120120 """
121121
122- def __init__ (self , nodes , * , fmm_order = False , fmm_level_to_order = None ,
123- expansion_factory = default_expansion_factory ,
124- tree_build_kwargs = None ,
125- trav_build_kwargs = None ):
122+ def __init__ (self , nodes , * ,
123+ fmm_order : Optional [int ] = False ,
124+ fmm_level_to_order : Optional [Union [bool , Callable [..., int ]]] = None ,
125+ expansion_factory : Optional [DefaultExpansionFactory ] \
126+ = default_expansion_factory ,
127+ tree_build_kwargs : Optional [Mapping ] = None ,
128+ trav_build_kwargs : Optional [Mapping ] = None ,
129+ setup_actx : Optional [ArrayContext ] = None
130+ ):
131+ """
132+ :arg nodes: The point potential source given as a
133+ :class:`pyopencl.array.Array`
134+ :arg fmm_order: The order of the FMM for all levels if *fmm_order* is not
135+ *False*. Mutually exclusive with argument *fmm_level_to_order*.
136+ If both arguments are not given a direct point-to-point calculation
137+ is used.
138+ :arg fmm_level_to_order: An optional callable that returns the FMM order
139+ to use for a given level. Mutually exclusive with *fmm_order* argument.
140+ :arg expansion_factory: An expansion factory to get the expansion objects
141+ when an FMM is used.
142+ :arg tree_build_kwargs: Keyword arguments to be passed when building the
143+ tree for an FMM.
144+ :arg trav_build_kwargs: Keyword arguments to be passed when building a
145+ traversal for an FMM.
146+ :arg setup_actx: An array context to be used when building a tree
147+ for an FMM.
148+ """
126149
127150 if fmm_order is not False and fmm_level_to_order is not None :
128151 raise TypeError ("may not specify both fmm_order and fmm_level_to_order" )
@@ -137,6 +160,7 @@ def fmm_level_to_order(kernel, kernel_args, tree, level): # noqa pylint:disable
137160 self .expansion_factory = expansion_factory
138161 self .tree_build_kwargs = tree_build_kwargs if tree_build_kwargs else {}
139162 self .trav_build_kwargs = trav_build_kwargs if trav_build_kwargs else {}
163+ self ._setup_actx = setup_actx
140164 self ._nodes = nodes
141165
142166 @property
@@ -159,10 +183,13 @@ def ndofs(self):
159183 for coord_ary in self ._nodes :
160184 return coord_ary .shape [0 ]
161185
162- def copy (self , nodes = None , fmm_order = None , fmm_level_to_order = None ,
163- expansion_factory = None , tree_build_kwargs = None , trav_build_kwargs = None ):
186+ def copy (self , * , nodes = None , fmm_order = None , fmm_level_to_order = None ,
187+ expansion_factory = None , tree_build_kwargs = None , trav_build_kwargs = None ,
188+ setup_actx = None ):
164189 if nodes is None :
165190 nodes = self ._nodes
191+ if setup_actx is None :
192+ setup_actx = self ._setup_actx
166193 if fmm_level_to_order is None and fmm_order is None :
167194 fmm_level_to_order = self .fmm_level_to_order
168195 if expansion_factory is None :
@@ -179,6 +206,7 @@ def copy(self, nodes=None, fmm_order=None, fmm_level_to_order=None,
179206 expansion_factory = expansion_factory ,
180207 tree_build_kwargs = tree_build_kwargs ,
181208 trav_build_kwargs = trav_build_kwargs ,
209+ setup_actx = setup_actx ,
182210 )
183211
184212 @property
@@ -194,6 +222,7 @@ def ambient_dim(self):
194222
195223 def op_group_features (self , expr ):
196224 from pytential .utils import sort_arrays_together
225+ from sumpy .kernel import TargetTransformationRemover
197226 # since IntGs with the same source kernels and densities calculations
198227 # for P2E and E2E are the same and only differs in E2P depending on the
199228 # target kernel, we group all IntGs with same source kernels and densities.
@@ -212,14 +241,33 @@ def cost_model_compute_potential_insn(self, actx, insn, bound_expr,
212241 raise NotImplementedError
213242
214243 @memoize_method
215- def _get_exec_insn_func (self , actx , source_kernels , target_kernels ):
244+ def _get_tree (self , target_discr ):
245+ """Builds a tree for targets given by *target_discr* and caches the
246+ result. Needed only when an FMM is used.
247+ """
248+ from boxtree import TreeBuilder
249+ from boxtree .traversal import FMMTraversalBuilder
250+
251+ actx = self ._setup_actx
252+ sources = self ._nodes
253+ targets = flatten (target_discr .nodes (), actx , leaf_class = DOFArray )
254+ tree_build = TreeBuilder (actx .context )
255+ trav_build = FMMTraversalBuilder (actx .context ,
256+ ** self .trav_build_kwargs )
257+ tree , _ = tree_build (actx .queue , sources , targets = targets ,
258+ ** self .tree_build_kwargs )
259+ trav , _ = trav_build (actx .queue , tree )
260+ return tree , trav
216261
262+ @memoize_method
263+ def _get_exec_insn_func (self , source_kernels , target_kernels , target_discr ):
217264 if self .fmm_level_to_order is False :
218- p2p = self .get_p2p (actx , source_kernels = source_kernels ,
265+ sources = self ._nodes
266+ targets = flatten (target_discr .nodes (), self ._setup_actx , leaf_class = DOFArray )
267+ def exec_insn (actx , strengths , kernel_args , dtype , return_timing_data ):
268+ p2p = self .get_p2p (actx , source_kernels = source_kernels ,
219269 target_kernels = target_kernels )
220270
221- def exec_insn (sources , targets , strengths , kernel_args ,
222- dtype , return_timing_data ):
223271 evt , output = p2p (actx .queue ,
224272 targets = targets ,
225273 sources = sources ,
@@ -232,14 +280,8 @@ def exec_insn(sources, targets, strengths, kernel_args,
232280 timing_data = None
233281 return timing_data , output
234282 else :
235- from boxtree import TreeBuilder
236- from boxtree .traversal import FMMTraversalBuilder
237283 from boxtree .fmm import drive_fmm
238284
239- tree_build = TreeBuilder (actx .context )
240- trav_build = FMMTraversalBuilder (actx .context ,
241- ** self .trav_build_kwargs )
242-
243285 kernel = target_kernels [0 ].get_base_kernel ()
244286 local_expansion_factory = \
245287 self .expansion_factory .get_local_expansion_class (kernel )
@@ -248,11 +290,9 @@ def exec_insn(sources, targets, strengths, kernel_args,
248290 self .expansion_factory .get_multipole_expansion_class (kernel )
249291 mpole_expansion_factory = partial (mpole_expansion_factory , kernel )
250292
251- def exec_insn (sources , targets , strengths , kernel_args ,
252- dtype , return_timing_data ):
253- tree , _ = tree_build (actx .queue , sources , targets = targets ,
254- ** self .tree_build_kwargs )
255- trav , _ = trav_build (actx .queue , tree )
293+ tree , trav = self ._get_tree (target_discr )
294+
295+ def exec_insn (actx , strengths , kernel_args , dtype , return_timing_data ):
256296 tree_indep = SumpyTreeIndependentDataForWrangler (
257297 actx .context ,
258298 mpole_expansion_factory ,
@@ -263,7 +303,6 @@ def exec_insn(sources, targets, strengths, kernel_args,
263303 wrangler = SumpyExpansionWrangler (tree_indep , trav , dtype ,
264304 fmm_level_to_order = self .fmm_level_to_order ,
265305 kernel_extra_kwargs = kernel_args )
266-
267306 timing_data = {} if return_timing_data else None
268307 output = drive_fmm (wrangler , strengths , timing_data = timing_data )
269308 return timing_data , output
@@ -283,12 +322,6 @@ def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
283322 else :
284323 dtype = self .real_dtype
285324
286- exec_insn = self ._get_exec_insn_func (
287- actx = actx ,
288- source_kernels = insn .source_kernels ,
289- target_kernels = insn .target_kernels ,
290- )
291-
292325 outputs_grouped_by_target = defaultdict (list )
293326 for o in insn .outputs :
294327 outputs_grouped_by_target [o .target_name ].append (o )
@@ -299,11 +332,14 @@ def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
299332 target_discr = bound_expr .places .get_discretization (
300333 target_name .geometry , target_name .discr_stage )
301334
302- sources = self ._nodes
303- targets = flatten (target_discr .nodes (), actx , leaf_class = DOFArray )
335+ exec_insn = self ._get_exec_insn_func (
336+ source_kernels = insn .source_kernels ,
337+ target_kernels = insn .target_kernels ,
338+ target_discr = target_discr ,
339+ )
304340
305341 timing_data , output_for_each_kernel = \
306- exec_insn (sources , targets , strengths , kernel_args ,
342+ exec_insn (actx , strengths , kernel_args ,
307343 dtype , return_timing_data )
308344 timing_data_arr .append (timing_data )
309345
0 commit comments