Skip to content

Commit 78cc4c3

Browse files
Make index check less strict
1 parent 49e2818 commit 78cc4c3

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

pymc_extras/statespace/utils/data_tools.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,19 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
8787
col_names = data.columns
8888
_validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)
8989

90-
if isinstance(data.index, pd.RangeIndex):
91-
if obs_coords is not None:
92-
warnings.warn(NO_TIME_INDEX_WARNING)
93-
return preprocess_numpy_data(data.values, n_obs, obs_coords)
94-
95-
elif isinstance(data.index, pd.DatetimeIndex):
90+
if isinstance(data.index, pd.DatetimeIndex):
9691
if data.index.freq is None:
9792
warnings.warn(NO_FREQ_INFO_WARNING)
9893
data.index.freq = data.index.inferred_freq
9994

10095
index = data.index
10196
return data.values, index
10297

98+
elif isinstance(data.index, pd.Index):
99+
if obs_coords is not None:
100+
warnings.warn(NO_TIME_INDEX_WARNING)
101+
return preprocess_numpy_data(data.values, n_obs, obs_coords)
102+
103103
else:
104104
raise IndexError(
105105
f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"

tests/statespace/test_coord_assignment.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from pymc_extras.statespace.models import structural
11+
from pymc_extras.statespace.models.structural import LevelTrendComponent
1112
from pymc_extras.statespace.utils.constants import (
1213
FILTER_OUTPUT_DIMS,
1314
FILTER_OUTPUT_NAMES,
@@ -114,3 +115,31 @@ def test_data_index_is_coord(f, warning, create_model):
114115
with warning:
115116
pymc_model = create_model(f)
116117
assert TIME_DIM in pymc_model.coords
118+
119+
120+
def test_integer_index():
121+
a = pd.DataFrame(
122+
index=np.arange(8), columns=["A", "B", "C", "D"], data=np.arange(32).reshape(8, 4)
123+
)
124+
125+
mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
126+
ss_mod = mod.build(name="a", verbose=False)
127+
128+
initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
129+
coords = ss_mod.coords
130+
131+
with pm.Model(coords=coords) as model_1:
132+
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
133+
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
134+
135+
initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
136+
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
137+
138+
with pytest.warns(UserWarning, match="No time index found on the supplied data"):
139+
ss_mod.build_statespace_graph(
140+
a["A"],
141+
mode="JAX",
142+
)
143+
144+
assert TIME_DIM in model_1.coords
145+
np.testing.assert_allclose(model_1.coords[TIME_DIM], a.index)

0 commit comments

Comments
 (0)