Skip to content

Commit ba4ebb6

Browse files
committed
Provide static shape in output of Split
1 parent 0b56ed9 commit ba4ebb6

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

pytensor/tensor/basic.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,8 +2201,26 @@ def make_node(self, x, axis, splits):
22012201
raise TypeError("`axis` parameter must be an integer scalar")
22022202

22032203
inputs = [x, axis, splits]
2204-
out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
2205-
outputs = [out_type() for i in range(self.len_splits)]
2204+
2205+
x_dtype = x.type.dtype
2206+
if isinstance(axis, Constant):
2207+
# In this case we can preserve more static shape info
2208+
static_axis = axis.data.item()
2209+
outputs = []
2210+
x_static_shape = list(x.type.shape)
2211+
for i in range(self.len_splits):
2212+
try:
2213+
static_split_size = int(get_scalar_constant_value(splits[i]))
2214+
except NotScalarConstantError:
2215+
static_split_size = None
2216+
static_out_shape = x_static_shape.copy()
2217+
static_out_shape[static_axis] = static_split_size
2218+
outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
2219+
else:
2220+
outputs = [
2221+
tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
2222+
for i in range(self.len_splits)
2223+
]
22062224

22072225
return Apply(self, inputs, outputs)
22082226

0 commit comments

Comments
 (0)