Skip to content

Commit b084a3e

Browse files
Add concat_with_broadcast
1 parent 7c3820b commit b084a3e

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.scalar import upcast
2828
from pytensor.tensor import TensorLike, as_tensor_variable
2929
from pytensor.tensor import basic as ptb
30-
from pytensor.tensor.basic import alloc, second
30+
from pytensor.tensor.basic import alloc, join, second
3131
from pytensor.tensor.exceptions import NotScalarConstantError
3232
from pytensor.tensor.math import abs as pt_abs
3333
from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,31 @@ def broadcast_with_others(a, others):
20182018
return brodacasted_vars
20192019

20202020

2021+
def concat_with_broadcast(tensor_list, dim=0):
2022+
"""
2023+
Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
2024+
"""
2025+
dim = dim if dim > 0 else tensor_list[0].ndim + dim
2026+
non_concat_shape = [None] * tensor_list[0].ndim
2027+
for tensor_inp in tensor_list:
2028+
for i, (bcast, sh) in enumerate(
2029+
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
2030+
):
2031+
if bcast or i == dim or non_concat_shape[i] is not None:
2032+
continue
2033+
non_concat_shape[i] = sh
2034+
2035+
assert non_concat_shape.count(None) == 1
2036+
2037+
bcast_tensor_inputs = []
2038+
for tensor_inp in tensor_list:
2039+
# We modify the concat_axis in place, as we don't need the list anywhere else
2040+
non_concat_shape[dim] = tensor_inp.shape[dim]
2041+
bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
2042+
2043+
return join(dim, *bcast_tensor_inputs)
2044+
2045+
20212046
__all__ = [
20222047
"searchsorted",
20232048
"cumsum",
@@ -2035,6 +2060,7 @@ def broadcast_with_others(a, others):
20352060
"ravel_multi_index",
20362061
"broadcast_shape",
20372062
"broadcast_to",
2063+
"concat_with_broadcast",
20382064
"geomspace",
20392065
"logspace",
20402066
"linspace",

tests/tensor/test_extra_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,3 +1333,19 @@ def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
13331333
atol=1e-6 if config.floatX.endswith("64") else 1e-4,
13341334
rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
13351335
)
1336+
1337+
1338+
def test_concat_with_broadcast():
1339+
rng = np.random.default_rng()
1340+
a = pt.tensor("a", shape=(1, 3, 5))
1341+
b = pt.tensor("b", shape=(5, 3, 10))
1342+
1343+
c = pt.concat_with_broadcast([a, b], dim=-1)
1344+
fn = function([a, b], c, mode="FAST_COMPILE")
1345+
assert c.type.shape == (5, 3, 15)
1346+
1347+
a_val = rng.normal(size=(1, 3, 5))
1348+
b_val = rng.normal(size=(5, 3, 10))
1349+
1350+
c_val = fn(a_val, b_val)
1351+
np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))

0 commit comments

Comments
 (0)