Skip to content

Commit 0c55627

Browse files
committed
Use advanced indexing in the direct connection; disable demoting global temporaries to private temporaries to account for advanced-indexing expressions
1 parent bd964da commit 0c55627

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

meshmode/array_context.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -404,7 +404,7 @@ def _alias_global_temporaries(t_unit):
404404
from loopy.kernel import KernelState
405405
from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
406406
CallKernel, ReturnFromKernel, Barrier)
407-
from loopy.schedule.tools import get_nearest_return_from_kernel
407+
from loopy.schedule.tools import get_return_from_kernel_mapping
408408
from pytools import UniqueNameGenerator
409409
from collections import defaultdict
410410

@@ -415,7 +415,7 @@ def _alias_global_temporaries(t_unit):
415415
if tv.address_space == AddressSpace.GLOBAL)
416416
temp_to_live_interval_start = {}
417417
temp_to_live_interval_end = {}
418-
return_from_kernel_idxs = get_nearest_return_from_kernel(kernel)
418+
return_from_kernel_idxs = get_return_from_kernel_mapping(kernel)
419419

420420
for sched_idx, sched_item in enumerate(kernel.linearization):
421421
if isinstance(sched_item, RunInstruction):
@@ -503,10 +503,13 @@ def _make_global_temporaries_private(t_unit):
503503
for read_insn in read_insns)):
504504
if len({knl.insn_inames(read_insn) for read_insn in read_insns}) == 1:
505505
knl = lp.assignment_to_subst(knl, tv.name)
506-
knl = precompute_for_single_kernel(
507-
knl, t_unit.callables_table, f"{tv.name}_subst",
508-
sweep_inames=(),
509-
temporary_address_space=lp.AddressSpace.PRIVATE)
506+
try:
507+
knl = precompute_for_single_kernel(
508+
knl, t_unit.callables_table, f"{tv.name}_subst",
509+
sweep_inames=(),
510+
temporary_address_space=lp.AddressSpace.PRIVATE)
511+
except RuntimeError:
512+
pass
510513

511514
return t_unit.with_kernel(knl)
512515

@@ -573,7 +576,14 @@ class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
573576
def transform_loopy_program(self, t_unit):
574577
import loopy as lp
575578

576-
t_unit = _make_global_temporaries_private(t_unit)
579+
# if len(t_unit.default_entrypoint.instructions) > 50:
580+
# import pudb; pu.db
581+
# 1/0
582+
# with open("nozzle.knl", "w") as f:
583+
# f.write(str(t_unit))
584+
# 1/0
585+
586+
# t_unit = _make_global_temporaries_private(t_unit)
577587
t_unit = _single_grid_work_group_transform(t_unit, self.queue.device)
578588
t_unit = lp.set_options(t_unit, "insert_gbarriers")
579589
t_unit = lp.linearize(lp.preprocess_kernel(t_unit))

meshmode/discretization/connection/direct.py

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -449,14 +449,28 @@ def batch_pick_knl():
449449
)["result"]
450450

451451
else:
452-
batch_result = actx.call_loopy(
453-
batch_pick_knl(),
454-
pick_list=point_pick_indices,
455-
ary=ary[batch.from_group_index],
456-
from_element_indices=batch._global_from_element_indices(
457-
actx, self.to_discr.groups[i_tgrp]),
458-
n_to_nodes=self.to_discr.groups[i_tgrp].nunit_dofs
459-
)["result"]
452+
if actx.permits_advanced_indexing:
453+
from_vec = ary[batch.from_group_index]
454+
from_element_indices = actx.thaw(
455+
batch._global_from_element_indices(
456+
actx, self.to_discr.groups[i_tgrp])
457+
).reshape((-1, 1))
458+
pick_list = actx.thaw(point_pick_indices)
459+
batch_result = actx.np.where(
460+
actx.np.not_equal(from_element_indices, -1),
461+
from_vec[from_element_indices, pick_list],
462+
0)
463+
assert batch_result.shape == (from_element_indices.size,
464+
pick_list.size)
465+
else:
466+
batch_result = actx.call_loopy(
467+
batch_pick_knl(),
468+
pick_list=point_pick_indices,
469+
ary=ary[batch.from_group_index],
470+
from_element_indices=batch._global_from_element_indices(
471+
actx, self.to_discr.groups[i_tgrp]),
472+
n_to_nodes=self.to_discr.groups[i_tgrp].nunit_dofs
473+
)["result"]
460474

461475
batched_data.append(batch_result)
462476

0 commit comments

Comments
 (0)