@@ -409,7 +409,7 @@ def _alias_global_temporaries(t_unit):
409409 from loopy .kernel import KernelState
410410 from loopy .schedule import (RunInstruction , EnterLoop , LeaveLoop ,
411411 CallKernel , ReturnFromKernel , Barrier )
412- from loopy .schedule .tools import get_nearest_return_from_kernel
412+ from loopy .schedule .tools import get_return_from_kernel_mapping
413413 from pytools import UniqueNameGenerator
414414 from collections import defaultdict
415415
@@ -420,7 +420,7 @@ def _alias_global_temporaries(t_unit):
420420 if tv .address_space == AddressSpace .GLOBAL )
421421 temp_to_live_interval_start = {}
422422 temp_to_live_interval_end = {}
423- return_from_kernel_idxs = get_nearest_return_from_kernel (kernel )
423+ return_from_kernel_idxs = get_return_from_kernel_mapping (kernel )
424424
425425 for sched_idx , sched_item in enumerate (kernel .linearization ):
426426 if isinstance (sched_item , RunInstruction ):
@@ -508,10 +508,13 @@ def _make_global_temporaries_private(t_unit):
508508 for read_insn in read_insns )):
509509 if len ({knl .insn_inames (read_insn ) for read_insn in read_insns }) == 1 :
510510 knl = lp .assignment_to_subst (knl , tv .name )
511- knl = precompute_for_single_kernel (
512- knl , t_unit .callables_table , f"{ tv .name } _subst" ,
513- sweep_inames = (),
514- temporary_address_space = lp .AddressSpace .PRIVATE )
511+ try :
512+ knl = precompute_for_single_kernel (
513+ knl , t_unit .callables_table , f"{ tv .name } _subst" ,
514+ sweep_inames = (),
515+ temporary_address_space = lp .AddressSpace .PRIVATE )
516+ except RuntimeError :
517+ pass
515518
516519 return t_unit .with_kernel (knl )
517520
@@ -578,7 +581,14 @@ class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
578581 def transform_loopy_program (self , t_unit ):
579582 import loopy as lp
580583
581- t_unit = _make_global_temporaries_private (t_unit )
584+ # if len(t_unit.default_entrypoint.instructions) > 50:
585+ # import pudb; pu.db
586+ # 1/0
587+ # with open("nozzle.knl", "w") as f:
588+ # f.write(str(t_unit))
589+ # 1/0
590+
591+ # t_unit = _make_global_temporaries_private(t_unit)
582592 t_unit = _single_grid_work_group_transform (t_unit , self .queue .device )
583593 t_unit = lp .set_options (t_unit , "insert_gbarriers" )
584594 t_unit = lp .linearize (lp .preprocess_kernel (t_unit ))
0 commit comments