2727from pytensor .scalar import upcast
2828from pytensor .tensor import TensorLike , as_tensor_variable
2929from pytensor .tensor import basic as ptb
30- from pytensor .tensor .basic import alloc , second
30+ from pytensor .tensor .basic import alloc , join , second
3131from pytensor .tensor .exceptions import NotScalarConstantError
3232from pytensor .tensor .math import abs as pt_abs
3333from pytensor .tensor .math import all as pt_all
@@ -2018,6 +2018,31 @@ def broadcast_with_others(a, others):
20182018 return brodacasted_vars
20192019
20202020
def concat_with_broadcast(tensor_list, dim=0):
    """
    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.

    Parameters
    ----------
    tensor_list : sequence of TensorVariable
        Tensors to concatenate. All inputs must have the same number of
        dimensions; every axis other than ``dim`` is broadcast to a common
        shape before joining.
    dim : int, optional
        Axis along which to concatenate. Negative values count from the end.
        Defaults to 0.

    Returns
    -------
    TensorVariable
        The concatenation of the broadcast inputs along ``dim``.

    Raises
    ------
    ValueError
        If ``tensor_list`` is empty, or the inputs do not all share the same
        number of dimensions.
    """
    if not tensor_list:
        raise ValueError("concat_with_broadcast requires at least one input tensor")

    ndim = tensor_list[0].ndim
    if any(t.ndim != ndim for t in tensor_list):
        raise ValueError(
            "All inputs to concat_with_broadcast must have the same number of dimensions"
        )

    # Normalize a negative axis. Must be `dim >= 0` (not `dim > 0`): with the
    # default dim=0 the old `> 0` test wrongly remapped 0 to ndim, producing an
    # out-of-range join axis and tripping the assertion below.
    dim = dim if dim >= 0 else ndim + dim

    # For each non-concatenated axis, take the symbolic length from the first
    # input that is not broadcastable there; broadcastable axes impose no
    # constraint. NOTE(review): if *every* input is broadcastable along some
    # non-concat axis, no length is recorded and the assertion below fires —
    # existing limitation, kept as-is.
    non_concat_shape = [None] * ndim
    for tensor_inp in tensor_list:
        for i, (bcast, sh) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if bcast or i == dim or non_concat_shape[i] is not None:
                continue
            non_concat_shape[i] = sh

    # Only the concatenation axis may remain undetermined at this point.
    assert non_concat_shape.count(None) == 1

    bcast_tensor_inputs = []
    for tensor_inp in tensor_list:
        # We modify the concat_axis in place, as we don't need the list anywhere else
        non_concat_shape[dim] = tensor_inp.shape[dim]
        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

    return join(dim, *bcast_tensor_inputs)
2045+
20212046__all__ = [
20222047 "searchsorted" ,
20232048 "cumsum" ,
@@ -2035,6 +2060,7 @@ def broadcast_with_others(a, others):
20352060 "ravel_multi_index" ,
20362061 "broadcast_shape" ,
20372062 "broadcast_to" ,
2063+ "concat_with_broadcast" ,
20382064 "geomspace" ,
20392065 "logspace" ,
20402066 "linspace" ,
0 commit comments