Skip to content

Code generation for concatenate should be tree-of-ifs, not chain-of-ifs #495

@inducer

Description

@inducer

def map_concatenate(self, expr: Concatenate) -> IndexLambda:
from pymbolic.primitives import If, Comparison, Subscript
def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript:
aggregate = prim.Variable(f"_in{array_index}")
index = [prim.Variable(f"_{i}")
if i != expr.axis
else (prim.Variable(f"_{i}") - offset)
for i in range(len(expr.shape))]
return Subscript(aggregate, tuple(index))
lbounds: List[Any] = [0]
ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]
for i, array in enumerate(expr.arrays[1:], start=1):
ubounds.append(ubounds[i-1]+array.shape[expr.axis])
lbounds.append(ubounds[i-1])
# I = axis index
#
# => If(_I < arrays[0].shape[axis],
# _in0[_0, _1, ..., _I, ...],
# If(_I < (arrays[1].shape[axis]+arrays[0].shape[axis]),
# _in1[_0, _1, ..., _I-arrays[0].shape[axis], ...],
# ...
# _inNm1[_0, _1, ...] ...))
for i in range(len(expr.arrays) - 1, -1, -1):
lbound, ubound = lbounds[i], ubounds[i]
subarray_expr = get_subscript(i, lbound)
if i == len(expr.arrays) - 1:
concat_expr = subarray_expr
else:
concat_expr = If(Comparison(prim.Variable(f"_{expr.axis}"),
"<", ubound),
subarray_expr,
concat_expr)
bindings = {f"_in{i}": self.rec(array)
for i, array in enumerate(expr.arrays)}
return IndexLambda(expr=concat_expr,
shape=self._rec_shape(expr.shape),
dtype=expr.dtype,
bindings=immutabledict(bindings),
axes=expr.axes,
var_to_reduction_descr=immutabledict(),
tags=expr.tags)

cc @majosm

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions