@@ -404,7 +404,7 @@ def _alias_global_temporaries(t_unit):
404404 from loopy .kernel import KernelState
405405 from loopy .schedule import (RunInstruction , EnterLoop , LeaveLoop ,
406406 CallKernel , ReturnFromKernel , Barrier )
407- from loopy .schedule .tools import get_nearest_return_from_kernel
407+ from loopy .schedule .tools import get_return_from_kernel_mapping
408408 from pytools import UniqueNameGenerator
409409 from collections import defaultdict
410410
@@ -415,7 +415,7 @@ def _alias_global_temporaries(t_unit):
415415 if tv .address_space == AddressSpace .GLOBAL )
416416 temp_to_live_interval_start = {}
417417 temp_to_live_interval_end = {}
418- return_from_kernel_idxs = get_nearest_return_from_kernel (kernel )
418+ return_from_kernel_idxs = get_return_from_kernel_mapping (kernel )
419419
420420 for sched_idx , sched_item in enumerate (kernel .linearization ):
421421 if isinstance (sched_item , RunInstruction ):
@@ -503,10 +503,13 @@ def _make_global_temporaries_private(t_unit):
503503 for read_insn in read_insns )):
504504 if len ({knl .insn_inames (read_insn ) for read_insn in read_insns }) == 1 :
505505 knl = lp .assignment_to_subst (knl , tv .name )
506- knl = precompute_for_single_kernel (
507- knl , t_unit .callables_table , f"{ tv .name } _subst" ,
508- sweep_inames = (),
509- temporary_address_space = lp .AddressSpace .PRIVATE )
506+ try :
507+ knl = precompute_for_single_kernel (
508+ knl , t_unit .callables_table , f"{ tv .name } _subst" ,
509+ sweep_inames = (),
510+ temporary_address_space = lp .AddressSpace .PRIVATE )
511+ except RuntimeError :
512+ pass
510513
511514 return t_unit .with_kernel (knl )
512515
@@ -573,7 +576,14 @@ class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
573576 def transform_loopy_program (self , t_unit ):
574577 import loopy as lp
575578
576- t_unit = _make_global_temporaries_private (t_unit )
579+ # if len(t_unit.default_entrypoint.instructions) > 50:
580+ # import pudb; pu.db
581+ # 1/0
582+ # with open("nozzle.knl", "w") as f:
583+ # f.write(str(t_unit))
584+ # 1/0
585+
586+ # t_unit = _make_global_temporaries_private(t_unit)
577587 t_unit = _single_grid_work_group_transform (t_unit , self .queue .device )
578588 t_unit = lp .set_options (t_unit , "insert_gbarriers" )
579589 t_unit = lp .linearize (lp .preprocess_kernel (t_unit ))
0 commit comments