-
Notifications
You must be signed in to change notification settings - Fork 16
Open
Description
pytato/pytato/transform/lower_to_index_lambda.py
Lines 140 to 186 in 5aa8aa3
def map_concatenate(self, expr: Concatenate) -> IndexLambda:
    """Lower a concatenation node to a single :class:`IndexLambda`.

    The resulting scalar expression is a chain of nested ``If`` nodes that
    selects, based on the index along ``expr.axis``, which input array to
    read from. Input ``k`` is bound under the name ``_in{k}`` and is
    subscripted with the output indices ``_0, _1, ...``, where the index
    along the concatenation axis is shifted back by the combined extent of
    all preceding inputs.

    :arg expr: the concatenation node to lower.
    :returns: an :class:`IndexLambda` whose bindings are the recursively
        mapped entries of ``expr.arrays``.
    """
    from pymbolic.primitives import If, Comparison, Subscript

    axis = expr.axis
    ndim = len(expr.shape)

    def subscript_into(position: int, start: ScalarExpression) -> Subscript:
        # Every output index is passed through unchanged, except the one
        # along the concatenation axis, which is made relative to the
        # beginning of input *position*.
        indices = tuple(
            (prim.Variable(f"_{dim}") - start)
            if dim == axis
            else prim.Variable(f"_{dim}")
            for dim in range(ndim))
        return Subscript(prim.Variable(f"_in{position}"), indices)

    # starts[k] / stops[k] delimit the half-open range along *axis* that is
    # covered by input k. They are built directly from the entries of each
    # input's shape along that axis.
    starts: List[Any] = [0]
    stops: List[Any] = [expr.arrays[0].shape[axis]]
    for ary in expr.arrays[1:]:
        previous_stop = stops[-1]
        stops.append(previous_stop + ary.shape[axis])
        starts.append(previous_stop)

    # With I denoting the concatenation axis, the expression built below is
    #
    #   If(_I < stops[0],
    #      _in0[_0, ..., _I - starts[0], ...],
    #      If(_I < stops[1],
    #         _in1[_0, ..., _I - starts[1], ...],
    #         ...
    #         _inNm1[_0, ..., _I - starts[N-1], ...] ...))
    #
    # It is assembled inside-out: the last input is the innermost
    # (unconditional) branch and each earlier input wraps what has been
    # built so far.
    last = len(expr.arrays) - 1
    concat_expr = subscript_into(last, starts[last])
    for position in reversed(range(last)):
        concat_expr = If(
            Comparison(prim.Variable(f"_{axis}"), "<", stops[position]),
            subscript_into(position, starts[position]),
            concat_expr)

    bindings = {}
    for position, ary in enumerate(expr.arrays):
        bindings[f"_in{position}"] = self.rec(ary)

    return IndexLambda(expr=concat_expr,
                       shape=self._rec_shape(expr.shape),
                       dtype=expr.dtype,
                       bindings=immutabledict(bindings),
                       axes=expr.axes,
                       var_to_reduction_descr=immutabledict(),
                       tags=expr.tags)
cc @majosm
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels