@@ -27,7 +27,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, second
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,31 @@ def broadcast_with_others(a, others):
     return brodacasted_vars
 
 
+def concat_with_broadcast(tensor_list, dim=0):
+    """
+    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
+    """
+    dim = dim if dim >= 0 else tensor_list[0].ndim + dim
+    non_concat_shape = [None] * tensor_list[0].ndim
+    for tensor_inp in tensor_list:
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == dim or non_concat_shape[i] is not None:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_list:
+        # We modify the concat-axis entry in place, as we don't need the list anywhere else
+        non_concat_shape[dim] = tensor_inp.shape[dim]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    return join(dim, *bcast_tensor_inputs)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2035,6 +2060,7 @@ def broadcast_with_others(a, others):
     "ravel_multi_index",
     "broadcast_shape",
     "broadcast_to",
+    "concat_with_broadcast",
     "geomspace",
     "logspace",
     "linspace",