Skip to content

Commit c15f965

Browse files
Save static shape of last data dim
1 parent c9458e7 commit c15f965

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pymc_extras/statespace/utils/data_tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ def add_data_to_active_model(values, index, data_dims=None):
141141

142142
# If the data has just one column, we need to specify the shape as (None, 1), or else the JAX backend will
143143
# raise a broadcasting error.
144-
data_shape = None
145-
if values.shape[-1] == 1:
144+
if values.shape[-1] == 1 or values.ndim == 1:
146145
data_shape = (None, 1)
146+
else:
147+
data_shape = (None, values.shape[-1])
147148

148149
data = pm.Data("data", values, dims=data_dims, shape=data_shape)
149150

0 commit comments

Comments
 (0)