Description
`concatenate` requires all non-axis dimensions to match. This is annoying when we want to, say, pad a dimension with another vector, like `concatenate([some_matrix, some_vector], axis=-1)`. We need to manually do something like `concatenate([some_matrix, broadcast_to(some_vector, (some_matrix.shape[0], some_vector.shape[0]))], axis=-1)`.
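For concreteness, a minimal runnable version of the manual workaround (variable names are illustrative):

```python
import pytensor.tensor as pt

some_matrix = pt.matrix("some_matrix")  # shape (m, k)
some_vector = pt.vector("some_vector")  # shape (n,)

# Manually broadcast the vector so the non-concatenated dimension
# matches before concatenating along the last axis
padded = pt.concatenate(
    [
        some_matrix,
        pt.broadcast_to(some_vector, (some_matrix.shape[0], some_vector.shape[0])),
    ],
    axis=-1,
)  # shape (m, k + n)
```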
I suggest adding a helper that does the broadcasting automatically. It is basically the same code we need to lower `xtensor.concat` to tensor operations:
`pytensor/pytensor/xtensor/rewriting/shape.py`, lines 68 to 100 in `12213d0`:
```python
@register_lower_xtensor
@node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node):
    out_dims = node.outputs[0].type.dims
    concat_dim = node.op.dim
    concat_axis = out_dims.index(concat_dim)

    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

    # Broadcast non-concatenated dimensions of each input
    non_concat_shape = [None] * len(out_dims)
    for tensor_inp in tensor_inputs:
        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
        for i, (bcast, sh) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if bcast or i == concat_axis or non_concat_shape[i] is not None:
                continue
            non_concat_shape[i] = sh

    assert non_concat_shape.count(None) == 1

    bcast_tensor_inputs = []
    for tensor_inp in tensor_inputs:
        # We modify the concat_axis in place, as we don't need the list anywhere else
        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
    return [new_out]
```
We can refactor that into a helper that works directly on tensor inputs, offer it to users, and reuse it in the lowering rewrite.
Call it `concat_with_broadcast` or `xconcat`?
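A minimal sketch of what such a helper could look like, working directly on tensor inputs (untested; the name is just one of the candidates above, the signature is hypothetical, and it makes the same shape_unsafe assumption as the rewrite):

```python
import pytensor.tensor as pt

def concat_with_broadcast(tensors, axis=0):
    """Concatenate tensors, broadcasting every non-concatenated
    dimension to a common shape first (shape_unsafe, like the rewrite)."""
    ndim = max(t.ndim for t in tensors)
    # Left-pad lower-dimensional inputs with broadcastable dims,
    # following NumPy broadcasting rules
    tensors = [pt.shape_padleft(t, ndim - t.ndim) for t in tensors]
    axis = axis % ndim  # normalize negative axis

    # For each non-concatenated dimension, take the shape from the first
    # input that is not statically broadcastable along it
    non_concat_shape = [None] * ndim
    for t in tensors:
        for i, (bcast, sh) in enumerate(zip(t.type.broadcastable, t.shape)):
            if bcast or i == axis or non_concat_shape[i] is not None:
                continue
            non_concat_shape[i] = sh

    bcast_tensors = []
    for t in tensors:
        shape = list(non_concat_shape)
        shape[axis] = t.shape[axis]
        # If every input is broadcastable along some dim, keep length 1
        shape = [t.shape[i] if s is None else s for i, s in enumerate(shape)]
        bcast_tensors.append(pt.broadcast_to(t, shape))

    return pt.join(axis, *bcast_tensors)
```

With that, the example above reduces to `concat_with_broadcast([some_matrix, some_vector], axis=-1)`, and `lower_concat` would only need to align dims before calling it.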