Skip to content

Commit 804df7e

Browse files
committed
advanced indexing, disable demoting to private temps to account for advanced indexing exprs
1 parent e6fa34b commit 804df7e

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

meshmode/array_context.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def _alias_global_temporaries(t_unit):
409409
from loopy.kernel import KernelState
410410
from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
411411
CallKernel, ReturnFromKernel, Barrier)
412-
from loopy.schedule.tools import get_nearest_return_from_kernel
412+
from loopy.schedule.tools import get_return_from_kernel_mapping
413413
from pytools import UniqueNameGenerator
414414
from collections import defaultdict
415415

@@ -420,7 +420,7 @@ def _alias_global_temporaries(t_unit):
420420
if tv.address_space == AddressSpace.GLOBAL)
421421
temp_to_live_interval_start = {}
422422
temp_to_live_interval_end = {}
423-
return_from_kernel_idxs = get_nearest_return_from_kernel(kernel)
423+
return_from_kernel_idxs = get_return_from_kernel_mapping(kernel)
424424

425425
for sched_idx, sched_item in enumerate(kernel.linearization):
426426
if isinstance(sched_item, RunInstruction):
@@ -508,10 +508,13 @@ def _make_global_temporaries_private(t_unit):
508508
for read_insn in read_insns)):
509509
if len({knl.insn_inames(read_insn) for read_insn in read_insns}) == 1:
510510
knl = lp.assignment_to_subst(knl, tv.name)
511-
knl = precompute_for_single_kernel(
512-
knl, t_unit.callables_table, f"{tv.name}_subst",
513-
sweep_inames=(),
514-
temporary_address_space=lp.AddressSpace.PRIVATE)
511+
try:
512+
knl = precompute_for_single_kernel(
513+
knl, t_unit.callables_table, f"{tv.name}_subst",
514+
sweep_inames=(),
515+
temporary_address_space=lp.AddressSpace.PRIVATE)
516+
except RuntimeError:
517+
pass
515518

516519
return t_unit.with_kernel(knl)
517520

@@ -578,7 +581,14 @@ class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
578581
def transform_loopy_program(self, t_unit):
579582
import loopy as lp
580583

581-
t_unit = _make_global_temporaries_private(t_unit)
584+
# if len(t_unit.default_entrypoint.instructions) > 50:
585+
# import pudb; pu.db
586+
# 1/0
587+
# with open("nozzle.knl", "w") as f:
588+
# f.write(str(t_unit))
589+
# 1/0
590+
591+
# t_unit = _make_global_temporaries_private(t_unit)
582592
t_unit = _single_grid_work_group_transform(t_unit, self.queue.device)
583593
t_unit = lp.set_options(t_unit, "insert_gbarriers")
584594
t_unit = lp.linearize(lp.preprocess_kernel(t_unit))

meshmode/discretization/connection/direct.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -425,14 +425,28 @@ def batch_pick_knl():
425425
)["result"]
426426

427427
else:
428-
batch_result = actx.call_loopy(
429-
batch_pick_knl(),
430-
pick_list=point_pick_indices,
431-
ary=ary[batch.from_group_index],
432-
from_element_indices=batch._global_from_element_indices(
433-
actx, self.to_discr.groups[i_tgrp]),
434-
n_to_nodes=self.to_discr.groups[i_tgrp].nunit_dofs
435-
)["result"]
428+
if actx.permits_advanced_indexing:
429+
from_vec = ary[batch.from_group_index]
430+
from_element_indices = actx.thaw(
431+
batch._global_from_element_indices(
432+
actx, self.to_discr.groups[i_tgrp])
433+
).reshape((-1, 1))
434+
pick_list = actx.thaw(point_pick_indices)
435+
batch_result = actx.np.where(
436+
actx.np.not_equal(from_element_indices, -1),
437+
from_vec[from_element_indices, pick_list],
438+
0)
439+
assert batch_result.shape == (from_element_indices.size,
440+
pick_list.size)
441+
else:
442+
batch_result = actx.call_loopy(
443+
batch_pick_knl(),
444+
pick_list=point_pick_indices,
445+
ary=ary[batch.from_group_index],
446+
from_element_indices=batch._global_from_element_indices(
447+
actx, self.to_discr.groups[i_tgrp]),
448+
n_to_nodes=self.to_discr.groups[i_tgrp].nunit_dofs
449+
)["result"]
436450

437451
batched_data.append(batch_result)
438452

0 commit comments

Comments
 (0)