2121"""
2222
2323from abc import ABC , abstractmethod
24+ from pytools import memoize_method
2425
2526import numpy as np
2627import loopy as lp
28+ from loopy .kernel .data import LocalInameTag
29+ import pymbolic .primitives as prim
2730
2831from sumpy .tools import KernelCacheMixin , gather_loopy_arguments
2932from loopy .version import MOST_RECENT_LANGUAGE_VERSION
@@ -70,7 +73,7 @@ def __init__(self, ctx, expansion, kernels,
7073
7174 self .ctx = ctx
7275 self .expansion = expansion
73- self .kernels = kernels
76+ self .kernels = tuple ( kernels )
7477 self .name = name or self .default_name
7578 self .device = device
7679
@@ -81,15 +84,18 @@ def __init__(self, ctx, expansion, kernels,
def default_name(self):
    """Return the fallback kernel name; concrete subclasses override this.

    ``__init__`` falls back to this value when no explicit ``name`` is
    given (``self.name = name or self.default_name``).
    """
    # NOTE(review): the decorators for this method lie outside the visible
    # hunk; the attribute-style access in __init__ suggests it is a
    # property -- confirm before relying on that.
8386
@memoize_method
def get_cached_loopy_knl_and_optimizations(self):
    """Return the expansion's loopy evaluator for ``self.kernels``, memoized.

    The value is whatever ``self.expansion.get_loopy_evaluator`` returns.
    Callers in this file unpack it as a pair ``(inner_knl, optimizations)``:
    ``inner_knl`` is merged into the outer kernel, and ``optimizations`` is
    iterated as a sequence of callables, each taking and returning a kernel.

    Memoizing keeps the evaluator from being rebuilt for every consumer
    (kernel construction and kernel optimization both ask for it).
    """
    return self.expansion.get_loopy_evaluator(self.kernels)
90+
def get_cache_key(self):
    """Return a tuple identifying this object for kernel caching.

    The key is made of the concrete class name, the expansion, and the
    kernels. The kernels are (re)wrapped in a tuple so the key keeps the
    same shape even if ``self.kernels`` was assigned some other sequence.
    """
    class_name = type(self).__name__
    return (class_name, self.expansion, tuple(self.kernels))
8693
8794 def add_loopy_eval_callable (
8895 self , loopy_knl : lp .TranslationUnit ) -> lp .TranslationUnit :
89- inner_knl = self .expansion . get_loopy_evaluator ( self . kernels )
96+ inner_knl , _ = self .get_cached_loopy_knl_and_optimizations ( )
9097 loopy_knl = lp .merge ([loopy_knl , inner_knl ])
9198 loopy_knl = lp .inline_callable_kernel (loopy_knl , "e2p" )
92- loopy_knl = lp .remove_unused_inames (loopy_knl )
9399 for kernel in self .kernels :
94100 loopy_knl = kernel .prepare_loopy_kernel (loopy_knl )
95101 loopy_knl = lp .tag_array_axes (loopy_knl , "targets" , "sep,C" )
@@ -117,33 +123,41 @@ class E2PFromSingleBox(E2PBase):
def default_name(self):
    """Return the kernel name used when no explicit name is supplied."""
    return "e2p_from_single_box"
119125
120- def get_kernel (self ):
126+ def get_kernel (self , max_ntargets_in_one_box ):
121127 ncoeffs = len (self .expansion )
122128 loopy_args = self .get_loopy_args ()
129+ max_work_items = min (32 , max (ncoeffs , max_ntargets_in_one_box ))
123130
124131 loopy_knl = lp .make_kernel (
125132 [
126133 "{[itgt_box]: 0<=itgt_box<ntgt_boxes}" ,
127- "{[itgt,idim]: itgt_start<=itgt<itgt_end and 0<=idim<dim}" ,
134+ "{[idim]: 0<=idim<dim}" ,
135+ "{[itgt_offset]: 0<=itgt_offset<max_ntargets_in_one_box}" ,
128136 "{[icoeff]: 0<=icoeff<ncoeffs}" ,
129137 "{[iknl]: 0<=iknl<nresults}" ,
138+ "{[dummy]: 0<=dummy<max_work_items}" ,
130139 ],
131140 self .get_kernel_scaling_assignment ()
132141 + ["""
133142 for itgt_box
134- <> tgt_ibox = target_boxes[itgt_box]
135- <> itgt_start = box_target_starts[tgt_ibox]
136- <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
143+ <> tgt_ibox = target_boxes[itgt_box] {id=fetch_init0}
144+ <> itgt_start = box_target_starts[tgt_ibox] {id=fetch_init1}
145+ <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox] \
146+ {id=fetch_init2}
137147
138148 <> center[idim] = centers[idim, tgt_ibox] {id=fetch_center}
139149
140150 <> coeffs[icoeff] = \
141151 src_expansions[tgt_ibox - src_base_ibox, icoeff] \
142152 {id=fetch_coeffs}
143153
144- for itgt
145- <> tgt[idim] = targets[idim, itgt] {id=fetch_tgt,dup=idim}
146- <> result_temp[iknl] = 0 {id=init_result,dup=iknl}
154+ for itgt_offset
155+ <> itgt = itgt_start + itgt_offset
156+ <> run_itgt = itgt<itgt_end
157+ <> tgt[idim] = targets[idim, itgt] {id=fetch_tgt, \
158+ dup=idim,if=run_itgt}
159+ <> result_temp[iknl] = 0 {id=init_result,dup=iknl, \
160+ if=run_itgt}
147161 [iknl]: result_temp[iknl] = e2p(
148162 [iknl]: result_temp[iknl],
149163 [icoeff]: coeffs[icoeff],
@@ -155,9 +169,9 @@ def get_kernel(self):
155169 targets,
156170 """ + "," .join (arg .name for arg in loopy_args ) + """
157171 ) {dep=fetch_coeffs:fetch_center:init_result:fetch_tgt,\
158- id=update_result}
172+ id=update_result,if=run_itgt }
159173 result[iknl, itgt] = result_temp[iknl] * kernel_scaling \
160- {id=write_result,dep=update_result}
174+ {id=write_result,dep=update_result,if=run_itgt }
161175 end
162176 end
163177 """ ],
@@ -182,7 +196,9 @@ def get_kernel(self):
182196 silenced_warnings = "write_race(*_result)" ,
183197 default_offset = lp .auto ,
184198 fixed_parameters = {"dim" : self .dim , "nresults" : len (self .kernels ),
185- "ncoeffs" : ncoeffs },
199+ "ncoeffs" : ncoeffs ,
200+ "max_work_items" : max_work_items ,
201+ "max_ntargets_in_one_box" : max_ntargets_in_one_box },
186202 lang_version = MOST_RECENT_LANGUAGE_VERSION )
187203
188204 loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
@@ -191,13 +207,39 @@ def get_kernel(self):
191207
192208 return loopy_knl
193209
def get_optimized_kernel(self, max_ntargets_in_one_box):
    """Build the single-box E2P kernel and apply the parallelization transforms.

    :arg max_ntargets_in_one_box: upper bound on the number of targets in
        any one box; forwarded unchanged to :meth:`get_kernel`.
    :returns: the transformed loopy kernel.

    Steps, in order:

    1. ``itgt_box`` is mapped to group axis 0 (``g.0``).
    2. ``itgt_offset`` and ``icoeff`` are split in chunks of 32 with the
       inner part tagged ``l.0``.
    3. The ``dummy`` iname (tagged ``l.0``) is added to the per-box setup
       instructions.
    4. ``coeffs`` is placed in the local address space.
    5. The transformations supplied with the expansion's evaluator are
       applied.
    6. Instructions nested in ``e2p_*`` inames tagged as local are stripped
       of ``itgt_offset_inner`` and of the ``run_itgt`` predicate.
    """
    # Only the transformations are needed here, not the evaluator kernel.
    _, optimizations = self.get_cached_loopy_knl_and_optimizations()

    knl = self.get_kernel(max_ntargets_in_one_box=max_ntargets_in_one_box)
    knl = lp.tag_inames(knl, {"itgt_box": "g.0"})
    knl = lp.split_iname(knl, "itgt_offset", 32, inner_tag="l.0")
    knl = lp.split_iname(knl, "icoeff", 32, inner_tag="l.0")
    knl = lp.add_inames_to_insn(knl, "dummy",
        "id:fetch_init* or id:fetch_center or id:kernel_scaling")
    knl = lp.add_inames_to_insn(knl, "itgt_box", "id:kernel_scaling")
    knl = lp.tag_inames(knl, {"dummy": "l.0"})
    knl = lp.set_temporary_address_space(knl, "coeffs", lp.AddressSpace.LOCAL)
    knl = lp.set_options(knl,
        enforce_variable_access_ordered="no_check", write_code=False)

    for transform in optimizations:
        knl = transform(knl)

    # If there are inames tagged as local in the inner kernel
    # we need to remove the iname itgt_offset_inner from instructions
    # within those inames and also remove the predicate run_itgt
    # which depends on itgt_offset_inner.
    # NOTE(review): the "e2p_" prefix is presumably what inlining the
    # evaluator gives its inames -- confirm.
    local_e2p_inames = {
        iname.name
        for iname in knl.default_entrypoint.inames.values()
        if iname.name.startswith("e2p_")
        and any(isinstance(tag, LocalInameTag) for tag in iname.tags)
    }
    if local_e2p_inames:
        insn_ids = [
            insn.id
            for insn in knl.default_entrypoint.instructions
            if not local_e2p_inames.isdisjoint(insn.within_inames)
        ]
        # An empty id list would yield an empty match expression; there is
        # nothing to rewrite in that case, so skip the transforms.
        if insn_ids:
            match = " or ".join(f"id:{insn_id}" for insn_id in insn_ids)
            knl = lp.remove_inames_from_insn(knl,
                frozenset(["itgt_offset_inner"]), match)
            knl = lp.remove_predicates_from_insn(knl,
                frozenset([prim.Variable("run_itgt")]), match)

    return knl
203245
@@ -210,7 +252,9 @@ def __call__(self, queue, **kwargs):
210252 :arg centers:
211253 :arg targets:
212254 """
213- knl = self .get_cached_optimized_kernel ()
255+ max_ntargets_in_one_box = kwargs .pop ("max_ntargets_in_one_box" )
256+ knl = self .get_cached_optimized_kernel (
257+ max_ntargets_in_one_box = max_ntargets_in_one_box )
214258
215259 centers = kwargs .pop ("centers" )
216260 # "1" may be passed for rscale, which won't have its type
@@ -229,42 +273,49 @@ class E2PFromCSR(E2PBase):
def default_name(self):
    """Return the kernel name used when no explicit name is supplied."""
    return "e2p_from_csr"
231275
232- def get_kernel (self ):
276+ def get_kernel (self , max_ntargets_in_one_box ):
233277 ncoeffs = len (self .expansion )
234278 loopy_args = self .get_loopy_args ()
279+ max_work_items = min (32 , max (ncoeffs , max_ntargets_in_one_box ))
235280
236281 loopy_knl = lp .make_kernel (
237282 [
238283 "{[itgt_box]: 0<=itgt_box<ntgt_boxes}" ,
239- "{[itgt ]: itgt_start<=itgt<itgt_end }" ,
284+ "{[itgt_offset ]: 0<=itgt_offset<max_ntargets_in_one_box }" ,
240285 "{[isrc_box]: isrc_box_start<=isrc_box<isrc_box_end }" ,
241286 "{[idim]: 0<=idim<dim}" ,
242287 "{[icoeff]: 0<=icoeff<ncoeffs}" ,
243288 "{[iknl]: 0<=iknl<nresults}" ,
289+ "{[dummy]: 0<=dummy<max_work_items}" ,
244290 ],
245291 self .get_kernel_scaling_assignment ()
246292 + ["""
247293 for itgt_box
248- <> tgt_ibox = target_boxes[itgt_box]
249- <> itgt_start = box_target_starts[tgt_ibox]
250- <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
251-
252- for itgt
253- <> tgt[idim] = targets[idim,itgt] {id=fetch_tgt,dup=idim}
254-
255- <> isrc_box_start = source_box_starts[itgt_box]
256- <> isrc_box_end = source_box_starts[itgt_box+1]
257-
258- <> result_temp[iknl] = 0 {id=init_result,dup=iknl}
259- for isrc_box
260- <> src_ibox = source_box_lists[isrc_box]
261- <> coeffs[icoeff] = \
294+ <> tgt_ibox = target_boxes[itgt_box] {id=init_box0}
295+ <> itgt_start = box_target_starts[tgt_ibox] {id=init_box1}
296+ <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox] \
297+ {id=init_box2}
298+ <> isrc_box_start = source_box_starts[itgt_box] {id=init_box3}
299+ <> isrc_box_end = source_box_starts[itgt_box+1] {id=init_box4}
300+
301+ <> result_temp[itgt_offset, iknl] = 0 \
302+ {id=init_result,dup=iknl}
303+ for isrc_box
304+ <> src_ibox = source_box_lists[isrc_box] {id=fetch_src_box}
305+ <> coeffs[icoeff] = \
262306 src_expansions[src_ibox - src_base_ibox, icoeff] \
263- {id=fetch_coeffs,dup=icoeff }
264- <> center[idim] = centers[idim, src_ibox] \
307+ {id=fetch_coeffs}
308+ <> center[idim] = centers[idim, src_ibox] \
265309 {dup=idim,id=fetch_center}
266- [iknl]: result_temp[iknl] = e2p(
267- [iknl]: result_temp[iknl],
310+
311+ for itgt_offset
312+ <> itgt = itgt_start + itgt_offset
313+ <> run_itgt = itgt<itgt_end
314+ <> tgt[idim] = targets[idim,itgt] \
315+ {id=fetch_tgt,dup=idim,if=run_itgt}
316+
317+ [iknl]: result_temp[itgt_offset, iknl] = e2p(
318+ [iknl]: result_temp[itgt_offset, iknl],
268319 [icoeff]: coeffs[icoeff],
269320 [idim]: center[idim],
270321 [idim]: tgt[idim],
@@ -274,11 +325,18 @@ def get_kernel(self):
274325 targets,
275326 """ + "," .join (arg .name for arg in loopy_args ) + """
276327 ) {id=update_result, \
277- dep=fetch_coeffs:fetch_center:fetch_tgt:init_result}
328+ dep=fetch_coeffs:fetch_center:fetch_tgt:init_result, \
329+ if=run_itgt}
278330 end
279- result[iknl, itgt] = result[iknl, itgt] + result_temp[iknl] \
280- * kernel_scaling \
281- {dep=update_result:init_result,id=write_result,dup=iknl}
331+ end
332+ for itgt_offset
333+ <> itgt2 = itgt_start + itgt_offset {id=init_itgt_for_write}
334+ <> run_itgt2 = itgt_start + itgt_offset < itgt_end \
335+ {id=init_cond_for_write}
336+ result[iknl, itgt2] = result[iknl, itgt2] + result_temp[ \
337+ itgt_offset, iknl] * kernel_scaling \
338+ {dep=update_result:init_result,id=write_result, \
339+ dup=iknl,if=run_itgt2}
282340 end
283341 end
284342 """ ],
@@ -306,28 +364,48 @@ def get_kernel(self):
306364 fixed_parameters = {
307365 "ncoeffs" : ncoeffs ,
308366 "dim" : self .dim ,
367+ "max_work_items" : max_work_items ,
368+ "max_ntargets_in_one_box" : max_ntargets_in_one_box ,
309369 "nresults" : len (self .kernels )},
310370 lang_version = MOST_RECENT_LANGUAGE_VERSION )
311371
312372 loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
313373 loopy_knl = lp .tag_inames (loopy_knl , "iknl*:unr" )
314- loopy_knl = lp .prioritize_loops (loopy_knl , "itgt_box,itgt, isrc_box" )
374+ loopy_knl = lp .prioritize_loops (loopy_knl , "itgt_box,isrc_box,itgt_offset " )
315375 loopy_knl = self .add_loopy_eval_callable (loopy_knl )
316376 loopy_knl = lp .tag_array_axes (loopy_knl , "targets" , "sep,C" )
317377
318378 return loopy_knl
319379
def get_optimized_kernel(self, max_ntargets_in_one_box):
    """Build the CSR E2P kernel and apply the parallelization transforms.

    :arg max_ntargets_in_one_box: upper bound on the number of targets in
        any one box; forwarded unchanged to :meth:`get_kernel`.
    :returns: the transformed loopy kernel.

    ``itgt_box`` is mapped to group axis 0 and ``dummy`` to local axis 0.
    ``result_temp`` is unprivatized with respect to ``itgt_offset`` before
    that iname is split (chunks of 32, inner part ``l.0``) and re-privatized
    with respect to ``itgt_offset_outer`` afterwards. ``itgt_offset_outer``
    is then duplicated separately for the result initialization and for the
    write-back instructions. Finally ``coeffs`` is placed in the local
    address space and the transformations supplied with the expansion's
    evaluator are applied.
    """
    # Only the transformations are needed here, not the evaluator kernel.
    _, optimizations = self.get_cached_loopy_knl_and_optimizations()

    knl = self.get_kernel(max_ntargets_in_one_box=max_ntargets_in_one_box)
    knl = lp.tag_inames(knl, {"itgt_box": "g.0", "dummy": "l.0"})
    knl = lp.unprivatize_temporaries_with_inames(knl,
        "itgt_offset", "result_temp")
    knl = lp.split_iname(knl, "itgt_offset", 32, inner_tag="l.0")
    knl = lp.split_iname(knl, "icoeff", 32, inner_tag="l.0")
    knl = lp.privatize_temporaries_with_inames(knl,
        "itgt_offset_outer", "result_temp")
    knl = lp.duplicate_inames(knl, "itgt_offset_outer", "id:init_result")
    knl = lp.duplicate_inames(knl, "itgt_offset_outer",
        "id:write_result or id:init_itgt_for_write or id:init_cond_for_write")
    knl = lp.add_inames_to_insn(knl, "dummy",
        "id:init_box* or id:fetch_src_box or id:fetch_center "
        "or id:kernel_scaling")
    knl = lp.add_inames_to_insn(knl, "itgt_box", "id:kernel_scaling")
    # NOTE(review): no instruction id visible in this class's kernel starts
    # with "fetch_init" (the per-box setup ids here are init_box*), so this
    # match may select nothing -- confirm whether it is still needed.
    knl = lp.add_inames_to_insn(knl, "itgt_offset_inner", "id:fetch_init*")
    knl = lp.set_temporary_address_space(knl, "coeffs", lp.AddressSpace.LOCAL)
    knl = lp.set_options(knl,
        enforce_variable_access_ordered="no_check", write_code=False)

    for transform in optimizations:
        knl = transform(knl)

    return knl
328404
329405 def __call__ (self , queue , ** kwargs ):
330- knl = self .get_cached_optimized_kernel ()
406+ max_ntargets_in_one_box = kwargs .pop ("max_ntargets_in_one_box" )
407+ knl = self .get_cached_optimized_kernel (
408+ max_ntargets_in_one_box = max_ntargets_in_one_box )
331409
332410 centers = kwargs .pop ("centers" )
333411 # "1" may be passed for rscale, which won't have its type
0 commit comments