Skip to content

Commit 14b635c

Browse files
Validate generic index values
1 parent: 78cc4c3 · commit: 14b635c

File tree

2 files changed: 65 additions and 11 deletions.

pymc_extras/statespace/utils/data_tools.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -95,15 +95,33 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
9595
index = data.index
9696
return data.values, index
9797

98-
elif isinstance(data.index, pd.Index):
98+
elif isinstance(data.index, pd.RangeIndex):
9999
if obs_coords is not None:
100100
warnings.warn(NO_TIME_INDEX_WARNING)
101101
return preprocess_numpy_data(data.values, n_obs, obs_coords)
102102

103+
elif isinstance(data.index, pd.MultiIndex):
104+
if obs_coords is not None:
105+
warnings.warn(NO_TIME_INDEX_WARNING)
106+
107+
raise NotImplementedError("MultiIndex panel data is not currently supported.")
108+
103109
else:
104-
raise IndexError(
105-
f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
106-
)
110+
if obs_coords is not None:
111+
warnings.warn(NO_TIME_INDEX_WARNING)
112+
113+
index = data.index
114+
if not np.issubdtype(index.dtype, np.integer):
115+
raise IndexError("Provided index is not an integer index.")
116+
117+
if not index.is_monotonic_increasing:
118+
raise IndexError("Provided index is not monotonic increasing.")
119+
120+
index_diff = index.to_series().diff().dropna().values
121+
if not (index_diff == 1).all():
122+
raise IndexError("Provided index is not monotonic increasing.")
123+
124+
return preprocess_numpy_data(data.values, n_obs, obs_coords)
107125

108126

109127
def add_data_to_active_model(values, index, data_dims=None):

tests/statespace/test_coord_assignment.py

Lines changed: 43 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -117,18 +117,17 @@ def test_data_index_is_coord(f, warning, create_model):
117117
assert TIME_DIM in pymc_model.coords
118118

119119

120-
def test_integer_index():
121-
a = pd.DataFrame(
122-
index=np.arange(8), columns=["A", "B", "C", "D"], data=np.arange(32).reshape(8, 4)
123-
)
120+
def make_model(index):
121+
n = len(index)
122+
a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4))
124123

125124
mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
126125
ss_mod = mod.build(name="a", verbose=False)
127126

128127
initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
129128
coords = ss_mod.coords
130129

131-
with pm.Model(coords=coords) as model_1:
130+
with pm.Model(coords=coords) as model:
132131
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
133132
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
134133

@@ -140,6 +139,43 @@ def test_integer_index():
140139
a["A"],
141140
mode="JAX",
142141
)
142+
return model
143+
144+
145+
def test_integer_index():
146+
index = np.arange(8).astype(int)
147+
model = make_model(index)
148+
assert TIME_DIM in model.coords
149+
np.testing.assert_allclose(model.coords[TIME_DIM], index)
150+
151+
152+
def test_float_index_raises():
153+
index = np.linspace(0, 1, 8)
154+
155+
with pytest.raises(IndexError, match="Provided index is not an integer index"):
156+
make_model(index)
157+
158+
159+
def test_non_strictly_monotone_index_raises():
160+
# Decreases
161+
index = [0, 1, 2, 1, 2, 3]
162+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
163+
make_model(index)
164+
165+
# Has gaps
166+
index = [0, 1, 2, 3, 5, 6]
167+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
168+
make_model(index)
169+
170+
# Has duplicates
171+
index = [0, 1, 1, 2, 3, 4]
172+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
173+
make_model(index)
174+
143175

144-
assert TIME_DIM in model_1.coords
145-
np.testing.assert_allclose(model_1.coords[TIME_DIM], a.index)
176+
def test_multiindex_raises():
177+
index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)])
178+
with pytest.raises(
179+
NotImplementedError, match="MultiIndex panel data is not currently supported"
180+
):
181+
make_model(index)

0 commit comments

Comments (0)