Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pymc_extras/statespace/utils/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
col_names = data.columns
_validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)

if isinstance(data.index, pd.RangeIndex):
if obs_coords is not None:
warnings.warn(NO_TIME_INDEX_WARNING)
return preprocess_numpy_data(data.values, n_obs, obs_coords)

elif isinstance(data.index, pd.DatetimeIndex):
if isinstance(data.index, pd.DatetimeIndex):
if data.index.freq is None:
warnings.warn(NO_FREQ_INFO_WARNING)
data.index.freq = data.index.inferred_freq

index = data.index
return data.values, index

elif isinstance(data.index, pd.Index):
if obs_coords is not None:
warnings.warn(NO_TIME_INDEX_WARNING)
return preprocess_numpy_data(data.values, n_obs, obs_coords)

else:
raise IndexError(
f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
Expand Down
29 changes: 29 additions & 0 deletions tests/statespace/test_coord_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from pymc_extras.statespace.models import structural
from pymc_extras.statespace.models.structural import LevelTrendComponent
from pymc_extras.statespace.utils.constants import (
FILTER_OUTPUT_DIMS,
FILTER_OUTPUT_NAMES,
Expand Down Expand Up @@ -114,3 +115,31 @@ def test_data_index_is_coord(f, warning, create_model):
with warning:
pymc_model = create_model(f)
assert TIME_DIM in pymc_model.coords


def test_integer_index():
a = pd.DataFrame(
index=np.arange(8), columns=["A", "B", "C", "D"], data=np.arange(32).reshape(8, 4)
)

mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
ss_mod = mod.build(name="a", verbose=False)

initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)

initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)

with pytest.warns(UserWarning, match="No time index found on the supplied data"):
ss_mod.build_statespace_graph(
a["A"],
mode="JAX",
)

assert TIME_DIM in model_1.coords
np.testing.assert_allclose(model_1.coords[TIME_DIM], a.index)