Description
#473 should now be merged, but as an experimental new feature. It introduced a bunch of technical debt because the feature branch (off main) was done a long time ago and we've had a lot of refactoring since. Details below...
- Delete the file BSTS_REFACTORING_CONCERNS.md
Overview
The BSTS (Bayesian Structural Time Series) feature branch adds two new model classes (BayesianBasisExpansionTimeSeries and StateSpaceTimeSeries) and modifies the InterruptedTimeSeries experiment class to support them. While the implementation is functional, there are significant deviations from the established patterns in CausalPy that reduce maintainability and violate key design principles.
This document outlines the major concerns and proposes solutions to align the BSTS implementation with CausalPy's architecture.
🚨 Critical Issues
1. API Inconsistency - Data Type Signatures (pymc_models.py)
Problem:
The new model classes break the established contract that all PyMCModel subclasses accept xr.DataArray:
# Existing pattern (all other models)
def build_model(self, X: xr.DataArray, y: xr.DataArray, coords: Dict[str, Any] | None)
def fit(self, X: xr.DataArray, y: xr.DataArray, coords: Dict[str, Any] | None)
# New BSTS models
def build_model(self, X: Optional[np.ndarray], y: np.ndarray, coords: Dict[str, Any] | None)
def fit(self, X: Optional[np.ndarray], y: np.ndarray, coords: Dict[str, Any] | None)
Impact:
- Violates the Liskov Substitution Principle
- Forces experiment classes to use isinstance() checks and data conversions
- Makes the API unpredictable for users
- Breaks polymorphism
Evidence:
- interrupted_time_series.py:163-164: Complex data conversion logic
- interrupted_time_series.py:157-158, 185-186, 204-205, 222-223, 246-247: Five repeated type checks
2. Missing treated_units Dimension (pymc_models.py)
Problem:
BSTS models omit the treated_units dimension that all other models include:
# Existing pattern
mu = pm.Deterministic("mu", ..., dims=["obs_ind", "treated_units"])
# New BSTS models
mu = pm.Deterministic("mu", mu_, dims="obs_ind")  # Missing treated_units!
Impact:
- Breaks the base class score() method (line 333 expects treated_units)
- Breaks the base class _data_setter() (lines 220-223 expect treated_units)
- Forces complete override of score() in both model classes
- Requires defensive checks throughout experiment plotting code
Evidence:
- pymc_models.py:1412, 1417: BSTS models use dims="obs_ind" only
- interrupted_time_series.py:319-321, 344-348, 369-371: ~15 conditional checks for treated_units in plotting
- interrupted_time_series.py:407-410, 432-433, 436-439: ~8 hasattr checks in data extraction
3. Return Type Inconsistency (pymc_models.py)
Problem:
StateSpaceTimeSeries.predict() returns xr.Dataset instead of az.InferenceData:
# Base class contract
def predict(self, X: xr.DataArray, ...) -> az.InferenceData
# StateSpaceTimeSeries violation
def predict(self, X: Optional[np.ndarray], ...) -> xr.Dataset  # Line 1811
Impact:
- Breaks polymorphism
- Requires defensive wrapping in experiment class (lines 213-214, 235-238)
- Users can't reliably use .predict() without checking instance types
Evidence:
# interrupted_time_series.py:213-214, 235-238
if not isinstance(self.pre_pred, az.InferenceData):
    self.pre_pred = az.InferenceData(posterior_predictive=self.pre_pred)
4. Code Duplication - Repeated Type Checks (interrupted_time_series.py)
Problem:
The same isinstance() check is repeated 5 times in __init__:
# Lines 157-158, 185-186, 204-205, 222-223, 246-247
is_bsts_like = isinstance(
    self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
)
Impact:
- Violates the DRY (Don't Repeat Yourself) principle
- Creates a maintenance burden: any change must be made in 5 places
- Makes code harder to read and follow
Comparison:
Other experiment classes (DifferenceInDifferences, SyntheticControl, PrePostNEGD) do ONE type check:
if isinstance(self.model, PyMCModel):
    # PyMC logic
    ...
elif isinstance(self.model, RegressorMixin):
    # SKL logic
    ...
5. Violation of Open/Closed Principle (interrupted_time_series.py)
Problem:
The experiment class imports and explicitly checks for specific model types:
from causalpy.pymc_models import (
    BayesianBasisExpansionTimeSeries,  # Tight coupling
    PyMCModel,
    StateSpaceTimeSeries,  # Tight coupling
)
Impact:
- Adding new time-series models requires modifying the experiment class
- Breaks the abstraction provided by the PyMCModel base class
- Violates Open/Closed Principle (open for extension, closed for modification)
Comparison:
Other experiment files only import base classes:
# diff_in_diff.py, synthetic_control.py, etc.
from causalpy.pymc_models import PyMCModel
⚠️ Major Issues
6. Special Coordinate Requirements (pymc_models.py)
Problem:
BSTS models require datetime_index as pd.DatetimeIndex in coords, and pop it from the dictionary:
# Line 1281 (BayesianBasisExpansionTimeSeries)
datetime_index = coords.pop("datetime_index", None)
Impact:
- Makes API less predictable
- datetime_index is not preserved in model coordinates
- Users must know special requirements for these models
Standard Pattern:
# Standard coords
{"coeffs": [...], "obs_ind": [...], "treated_units": [...]}
7. Non-Standard Model Context (pymc_models.py)
Problem:
StateSpaceTimeSeries creates a separate model context instead of using self:
# Existing pattern
with self:  # Use the PyMCModel instance as context
    self.add_coords(coords)
    # ... model definition
# StateSpaceTimeSeries (Line 1717-1736)
with pm.Model(coords=coordinates) as self.second_model:
    # ... model definition
    ...
Impact:
- Confusing because StateSpaceTimeSeries inherits from pm.Model
- Breaks Liskov Substitution Principle
- Methods expecting with self: won't work correctly
- Creates maintenance complexity
8. No Prior Configuration System (pymc_models.py)
Problem:
BSTS models don't use the standard default_priors system:
# Existing pattern
default_priors = {
    "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
    ...
}
# BSTS models - hard-coded priors
beta = pm.Normal("beta", mu=0, sigma=10, dims="coeffs")  # Line 1408
sigma = pm.HalfNormal("sigma", sigma=self.prior_sigma)  # Line 1415
Impact:
- Users can't customize priors using the standard Prior system
- Only prior_sigma is configurable via __init__
- Inconsistent with established patterns
9. Complex _data_setter() Override (pymc_models.py)
Problem:
BayesianBasisExpansionTimeSeries._data_setter() has a different signature:
# Base class
def _data_setter(self, X: xr.DataArray) -> None
# BayesianBasisExpansionTimeSeries (Line 1456-1536)
def _data_setter(self, X_pred: Optional[np.ndarray], coords_pred: Dict[str, Any]) -> None
Impact:
- Signature doesn't match base class
- Base predict() can't call it correctly
- Forces complete override of predict()
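To illustrate why the mismatched signature forces a full predict() override, here is a minimal sketch using stand-in classes. BaseModel and BasisExpansionModel are hypothetical and are not the CausalPy classes; they only mirror the one-argument versus two-argument shape described above.

```python
class BaseModel:
    """Stand-in for a base class whose predict() calls _data_setter(X)."""

    def _data_setter(self, X):
        self.data = X

    def predict(self, X):
        self._data_setter(X)  # the base class passes exactly one argument
        return self.data


class BasisExpansionModel(BaseModel):
    """Stand-in subclass whose override requires an extra argument."""

    def _data_setter(self, X_pred, coords_pred):
        self.data = (X_pred, coords_pred)


def inherited_predict_works(model):
    """Return True if the inherited predict() can call _data_setter."""
    try:
        model.predict([1.0, 2.0])
    except TypeError:
        return False
    return True


print(inherited_predict_works(BaseModel()))            # True
print(inherited_predict_works(BasisExpansionModel()))  # False
```

The second call fails with a TypeError because the inherited predict() supplies only X, which is why the subclass ends up reimplementing predict() as well.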
10. Extensive Conditional Logic in Plotting (interrupted_time_series.py)
Problem:
Plotting methods have ~15 conditional checks for treated_units dimension:
# Lines 319-321, 344-348, 369-371, etc.
pre_mu_plot = (
    pre_mu.isel(treated_units=0) if "treated_units" in pre_mu.dims else pre_mu
)
Impact:
- Makes plotting code verbose and hard to read
- Other plotting methods don't need this complexity
- Suggests data format should be standardized earlier
11. Inconsistent Data Handling Pattern (interrupted_time_series.py)
Problem:
Experiment stores data as xarray, then converts to numpy for BSTS:
# Lines 163-164
X_fit = self.pre_X.values if self.pre_X.shape[1] > 0 else None
y_fit = self.pre_y.isel(treated_units=0).values
Impact:
- Data stored in one format but used in another
- Conversion logic is complex and error-prone
- Complex conditional: if self.pre_X.shape[1] > 0 else None
Standard Pattern:
# synthetic_control.py, lines 152-156
self.model.fit(
    X=self.datapre_control,  # xarray passed directly
    y=self.datapre_treated,
    coords=COORDS,
)
12. State Management Complexity (pymc_models.py)
Problem:
BayesianBasisExpansionTimeSeries maintains hidden state:
# Line 1110, 1111
self._first_fit_timestamp: Optional[pd.Timestamp] = None
self._exog_var_names: Optional[List[str]] = None
# Line 1247
if self._first_fit_timestamp is None:
    self._first_fit_timestamp = datetime_index[0]
Impact:
- Makes model stateful in non-obvious ways
- First call to fit() permanently sets _first_fit_timestamp
- Subsequent predictions use this for time calculations
- No clear way to reset the model
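The "first fit wins" behaviour can be reproduced with a small stand-in. StatefulTimeModel, days_since_origin() and reset() are hypothetical names used only for this sketch; it is not the actual BayesianBasisExpansionTimeSeries code, and reset() is one possible remedy rather than an existing method.

```python
from datetime import datetime
from typing import List, Optional


class StatefulTimeModel:
    """Stand-in that mimics the 'first fit wins' timestamp behaviour."""

    def __init__(self) -> None:
        self._first_fit_timestamp: Optional[datetime] = None

    def fit(self, datetime_index: List[datetime]) -> None:
        # Only the first call sets the reference timestamp.
        if self._first_fit_timestamp is None:
            self._first_fit_timestamp = datetime_index[0]

    def days_since_origin(self, when: datetime) -> int:
        if self._first_fit_timestamp is None:
            raise RuntimeError("fit() must be called first")
        return (when - self._first_fit_timestamp).days

    def reset(self) -> None:
        """Hypothetical helper: clear fit-dependent state explicitly."""
        self._first_fit_timestamp = None


model = StatefulTimeModel()
model.fit([datetime(2020, 1, 1), datetime(2020, 1, 2)])
model.fit([datetime(2021, 1, 1)])  # silently keeps the 2020 origin
stale = model.days_since_origin(datetime(2021, 1, 1))

model.reset()
model.fit([datetime(2021, 1, 1)])
fresh = model.days_since_origin(datetime(2021, 1, 1))
print(stale, fresh)  # 366 0
```

Refitting on new data without a reset keeps measuring time from the old origin, which is the non-obvious statefulness described above.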
🔧 Proposed Solutions
Solution 1: Create TimeSeriesPyMCModel Abstract Base Class
Approach:
Create a new abstract base class that handles time-series-specific requirements:
class TimeSeriesPyMCModel(PyMCModel):
    """Base class for time series models with datetime indices."""

    def build_model(
        self,
        X: Optional[np.ndarray],
        y: np.ndarray,
        coords: Dict[str, Any],
    ) -> None:
        """
        Time series models use numpy arrays and require datetime_index in coords.

        Parameters
        ----------
        X : np.ndarray or None
            Exogenous variables
        y : np.ndarray
            Target variable (1D)
        coords : dict
            Must contain "datetime_index" (pd.DatetimeIndex)
        """
        raise NotImplementedError

    def fit(
        self,
        X: Optional[np.ndarray],
        y: np.ndarray,
        coords: Dict[str, Any],
    ) -> az.InferenceData:
        """Fit time series model."""
        raise NotImplementedError

    # Add time-series specific helper methods
    def _validate_datetime_index(self, coords: Dict[str, Any]) -> pd.DatetimeIndex:
        """Extract and validate datetime index from coords."""
        ...
Benefits:
- Clear separation between standard and time-series models
- Experiment classes can use isinstance(model, TimeSeriesPyMCModel) once
- Documents the different requirements
- Allows future time-series models to extend easily
Solution 2: Add treated_units Dimension to BSTS Models
Approach:
Modify BSTS models to always include treated_units=["unit_0"]:
# In build_model()
model_coords = {
    "obs_ind": np.arange(num_obs),
    "treated_units": ["unit_0"],  # Add this
}
# Update mu definition
mu = pm.Deterministic("mu", mu_, dims=["obs_ind", "treated_units"])  # Add treated_units
Benefits:
- Maintains consistency with other models
- Base class methods work without modification
- Eliminates ~23 conditional checks in experiment class
- Simpler plotting code
Trade-offs:
- Slightly more complex for truly univariate models
- But improves overall consistency
Solution 3: Standardize Return Types
Approach:
Make StateSpaceTimeSeries.predict() return az.InferenceData:
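def predict(self, ...) -> az.InferenceData:
    # ... existing logic ...
    # Wrap result in InferenceData before returning.
    # InferenceData groups are expected to be xr.Dataset objects, so build one
    # first (this assumes y_hat_final is an xr.DataArray).
    posterior_predictive = xr.Dataset({"y_hat": y_hat_final, "mu": y_hat_final})
    return az.InferenceData(posterior_predictive=posterior_predictive)
Benefits: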
- Maintains polymorphism
- No defensive wrapping needed in experiment class
- Users can rely on consistent API
Solution 4: Refactor Experiment Class to Reduce Duplication
Approach:
Extract repeated logic into helper methods:
class InterruptedTimeSeries(BaseExperiment):
    def __init__(self, ...):
        super().__init__(model=model)
        # ... setup ...
        # Single type check
        self._is_timeseries_model = isinstance(
            self.model, TimeSeriesPyMCModel  # Or use ABC
        )
        # Extract to methods
        self._fit_model()
        self._score_model()
        self._predict_pre_period()
        self._predict_post_period()
        self._calculate_impacts()

    def _prepare_data_for_model(self, X: xr.DataArray, y: xr.DataArray):
        """Handle data format conversion in one place."""
        if self._is_timeseries_model:
            return self._convert_to_timeseries_format(X, y)
        return X, y

    def _convert_to_timeseries_format(self, X, y):
        """Convert xarray to format expected by time series models."""
        X_numpy = X.values if X.shape[1] > 0 else None
        y_numpy = y.isel(treated_units=0).values
        return X_numpy, y_numpy
Benefits:
- Reduces duplication from 5 checks to 1
- Centralizes conversion logic
- Easier to test
- More maintainable
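The "easier to test" benefit can be seen in a dependency-free sketch of the same structure. All class names here are hypothetical stubs, and nested lists stand in for xarray and numpy data (rows are observations, inner lists are columns or treated units), so this shows only the shape of the refactor, not CausalPy behaviour.

```python
class PyMCModelStub:
    """Stand-in for the generic model base class."""


class TimeSeriesModelStub(PyMCModelStub):
    """Stand-in for the proposed time-series base class."""


class ExperimentStub:
    def __init__(self, model):
        self.model = model
        # Single type check, stored once and reused everywhere.
        self._is_timeseries_model = isinstance(model, TimeSeriesModelStub)

    def _prepare_data_for_model(self, X, y):
        """Handle data format conversion in one place."""
        if self._is_timeseries_model:
            return self._convert_to_timeseries_format(X, y)
        return X, y

    @staticmethod
    def _convert_to_timeseries_format(X, y):
        # No columns means "no exogenous variables".
        X_out = X if X and len(X[0]) > 0 else None
        # Keep only the first treated unit, giving a flat series.
        y_out = [row[0] for row in y]
        return X_out, y_out


X_empty = [[], []]
y_obs = [[1.0], [2.0]]
ts_result = ExperimentStub(TimeSeriesModelStub())._prepare_data_for_model(X_empty, y_obs)
std_result = ExperimentStub(PyMCModelStub())._prepare_data_for_model(X_empty, y_obs)
print(ts_result)   # (None, [1.0, 2.0])
print(std_result)  # ([[], []], [[1.0], [2.0]])
```

Because the conversion lives in one method, a unit test can exercise it directly without constructing or fitting a model.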
Solution 5: Implement Standard Prior System
Approach:
Add default_priors to BSTS models:
class BayesianBasisExpansionTimeSeries(PyMCModel):
    default_priors = {
        "beta": Prior("Normal", mu=0, sigma=10, dims="coeffs"),
        "sigma": Prior("HalfNormal", sigma=5),
    }

    def __init__(self, ..., priors: dict[str, Any] | None = None):
        super().__init__(sample_kwargs=sample_kwargs, priors=priors)
        # ... rest of init ...

    def build_model(self, ...):
        # Use self.priors instead of hard-coded values
        beta = self.priors["beta"].create_variable("beta")
        sigma = self.priors["sigma"].create_variable("sigma")
Benefits:
- Users can customize priors using standard system
- Consistent with other models
- Better defaults documented in one place
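One way the override behaviour could work is a plain dictionary merge in which user-supplied entries replace class-level defaults. The sketch below assumes that design; PriorSpec and ModelWithPriors are hypothetical stand-ins, not the real Prior class or the CausalPy base class, and the rejection of unknown names is an added illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class PriorSpec:
    """Hypothetical stand-in for a prior specification."""

    distribution: str
    parameters: Dict[str, Any] = field(default_factory=dict)


class ModelWithPriors:
    default_priors: Dict[str, PriorSpec] = {
        "beta": PriorSpec("Normal", {"mu": 0, "sigma": 10}),
        "sigma": PriorSpec("HalfNormal", {"sigma": 5}),
    }

    def __init__(self, priors: Optional[Dict[str, PriorSpec]] = None) -> None:
        unknown = set(priors or {}) - set(self.default_priors)
        if unknown:
            raise KeyError(f"Unknown prior names: {sorted(unknown)}")
        # User entries override defaults; the class-level dict is left untouched.
        self.priors = {**self.default_priors, **(priors or {})}


custom = ModelWithPriors({"sigma": PriorSpec("HalfNormal", {"sigma": 1})})
print(custom.priors["sigma"].parameters["sigma"])                   # 1
print(custom.priors["beta"].distribution)                           # Normal
print(ModelWithPriors.default_priors["sigma"].parameters["sigma"])  # 5
```

Building a new dictionary per instance keeps one model's customization from leaking into the shared defaults.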
Solution 6: Add Helper Method for Model Context
Approach:
For StateSpaceTimeSeries, document why separate context is needed:
class StateSpaceTimeSeries(PyMCModel):
    """
    Note: This model uses a separate PyMC Model context (self.second_model)
    instead of self due to requirements of the state-space implementation.
    This is necessary for pymc-extras state-space models.
    """

    def build_model(self, ...):
        # Current approach, but with clear documentation
        with pm.Model(coords=coordinates) as self.second_model:
            ...
Or if possible, refactor to use self:
def build_model(self, ...):
    with self:
        self.add_coords(coordinates)
        # ... build state-space model within self context
Implementation Plan
Phase 1: Quick Wins (Low Risk, High Impact)
- Add experimental warnings (DONE)
- Extract repeated type check in InterruptedTimeSeries.__init__ to single variable
- Add treated_units dimension to BSTS models
- Standardize StateSpaceTimeSeries.predict() return type
Phase 2: API Standardization (Medium Risk, High Impact)
- Create TimeSeriesPyMCModel abstract base class
- Refactor BSTS models to inherit from new base class
- Implement standard prior system in BSTS models
- Update experiment class to use ABC instead of explicit type checks
Phase 3: Code Quality (Low Risk, Medium Impact)
- Extract helper methods in InterruptedTimeSeries to reduce duplication
- Simplify plotting code (benefits from Phase 1 #3)
- Add comprehensive documentation about time-series model requirements
- Add tests for time-series model interface
Phase 4: Advanced Improvements (Optional)
- Consider adapter pattern to wrap BSTS models for xarray compatibility
- Evaluate state management approach in BayesianBasisExpansionTimeSeries
- Document or refactor StateSpaceTimeSeries model context usage
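The adapter idea from Phase 4 might look roughly like the sketch below. Every class here is a hypothetical stand-in: ArrayOnlyModel plays the role of a model that wants plain arrays, LabeledData plays the role of a labeled xarray-style container, and nested lists replace real arrays so the example has no dependencies.

```python
from typing import Dict, List, Optional, Sequence


class ArrayOnlyModel:
    """Stand-in for a model whose fit() wants plain sequences."""

    def fit(self, X: Optional[List[List[float]]], y: List[float]) -> Dict[str, int]:
        return {"n_obs": len(y), "n_features": len(X[0]) if X else 0}


class LabeledData:
    """Tiny stand-in for a labeled 2-D container (observations x columns)."""

    def __init__(self, rows: Sequence[Sequence[float]], columns: Sequence[str]) -> None:
        self.rows = [list(r) for r in rows]
        self.columns = list(columns)


class LabeledAdapter:
    """Accept labeled input and convert it once before delegating."""

    def __init__(self, wrapped: ArrayOnlyModel) -> None:
        self._wrapped = wrapped

    def fit(self, X: LabeledData, y: LabeledData) -> Dict[str, int]:
        X_plain = X.rows if X.columns else None  # no columns: no exogenous data
        y_plain = [row[0] for row in y.rows]     # single treated unit
        return self._wrapped.fit(X_plain, y_plain)


adapter = LabeledAdapter(ArrayOnlyModel())
with_exog = adapter.fit(
    LabeledData([[1.0], [2.0]], ["x"]),
    LabeledData([[3.0], [4.0]], ["unit_0"]),
)
without_exog = adapter.fit(
    LabeledData([[], []], []),
    LabeledData([[3.0], [4.0]], ["unit_0"]),
)
print(with_exog)     # {'n_obs': 2, 'n_features': 1}
print(without_exog)  # {'n_obs': 2, 'n_features': 0}
```

With this shape, the experiment class would talk only to the adapter, and the conversion would no longer need to live in InterruptedTimeSeries.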
🎯 Priority Assessment
| Issue | Priority | Impact | Effort | Phase |
|---|---|---|---|---|
| API Inconsistency (data types) | 🔴 Critical | High | Medium | 2 |
| Missing treated_units | 🔴 Critical | High | Low | 1 |
| Return Type Inconsistency | 🔴 Critical | High | Low | 1 |
| Code Duplication (5x checks) | 🔴 Critical | Medium | Low | 1 |
| Open/Closed Violation | 🔴 Critical | High | Medium | 2 |
| Special Coordinate Requirements | 🟡 Major | Medium | Medium | 2 |
| Non-Standard Model Context | 🟡 Major | Medium | High | 4 |
| No Prior Configuration | 🟡 Major | Medium | Medium | 2 |
| Complex _data_setter() | 🟡 Major | Medium | Medium | 2 |
| Extensive Plotting Conditionals | 🟡 Major | Low | Low | 3 |
| Inconsistent Data Handling | 🟡 Major | Medium | Low | 3 |
| State Management Complexity | 🟡 Major | Low | High | 4 |
Additional Considerations
Backward Compatibility
- Changes to model APIs will break existing BSTS user code
- Should version as breaking change (e.g., 0.5.0)
- Consider deprecation warnings before removal
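A deprecation path could keep the old call style working for a release while warning about it. The function below is a hypothetical illustration using the standard warnings module: a plain list stands in for the legacy numpy input and a dict stands in for the new labeled input.

```python
import warnings
from typing import Any, Sequence


def fit_legacy_aware(X: Any, y: Sequence[float]) -> int:
    """Hypothetical fit(): accept the legacy input form but warn about it."""
    if isinstance(X, (list, tuple)):
        warnings.warn(
            "Passing plain sequences to fit() is deprecated; pass labeled data instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        X = {"values": list(X)}  # convert to the new (stand-in) labeled form
    return len(X["values"])


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    n_legacy = fit_legacy_aware([1.0, 2.0], [0.5, 0.7])
    n_new = fit_legacy_aware({"values": [1.0, 2.0, 3.0]}, [0.5, 0.7, 0.9])

print(n_legacy, n_new, len(caught))  # 2 3 1
```

Only the legacy call emits a warning, so existing user code keeps running while the new form is adopted.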
Testing Requirements
- Add integration tests for time-series model interface
- Test that experiment class works with all model types
- Add tests for data format conversions
- Test prior customization system
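An interface test could compare method signatures of each model class against the base class. The sketch below uses inspect.signature from the standard library; Base, Conforming and Divergent are hypothetical stubs that stand in for the real model classes.

```python
import inspect
from typing import List, Sequence


def signature_mismatches(base: type, subclass: type, method_names: Sequence[str]) -> List[str]:
    """Return the method names whose parameter list differs from the base class."""
    mismatched = []
    for name in method_names:
        base_params = list(inspect.signature(getattr(base, name)).parameters)
        sub_params = list(inspect.signature(getattr(subclass, name)).parameters)
        if base_params != sub_params:
            mismatched.append(name)
    return mismatched


class Base:
    def fit(self, X, y, coords=None): ...
    def _data_setter(self, X): ...


class Conforming(Base):
    def fit(self, X, y, coords=None): ...


class Divergent(Base):
    def _data_setter(self, X_pred, coords_pred): ...


methods = ["fit", "_data_setter"]
print(signature_mismatches(Base, Conforming, methods))  # []
print(signature_mismatches(Base, Divergent, methods))   # ['_data_setter']
```

This checks parameter names only, not type annotations or return types, so it complements rather than replaces integration tests that actually fit each model.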
Documentation Needs
- Document time-series model requirements clearly
- Provide migration guide if API changes
- Add examples showing both standard and time-series models
- Document the TimeSeriesPyMCModel ABC if created
🤔 Open Questions
- State-space requirements: Can StateSpaceTimeSeries use self as context, or does pymc-extras require a separate model?
- Backward compatibility: How many users are already using these experimental models? Should we prioritize backward compatibility or clean API?
- Time-series ABC: Should TimeSeriesPyMCModel be a separate class hierarchy, or should we make PyMCModel more flexible?
- Data format: Is there value in making BSTS models accept xarray, or is numpy + datetime the right approach for time series?
- Prior system: Should time-series models support dimension-specific priors like dims=["obs_ind", "treated_units"]?
Conclusion
The BSTS implementation adds valuable functionality to CausalPy, but the current approach creates maintenance challenges and API inconsistencies. By following the proposed solutions, we can:
- Maintain the functionality while improving API consistency
- Reduce code duplication and improve maintainability
- Make the codebase more extensible for future time-series models
- Provide a better user experience with consistent interfaces
The experimental warnings currently in place give us breathing room to make breaking changes if needed. We should prioritize Phase 1 quick wins to address the most critical issues, then move to API standardization in Phase 2.