Skip to content

Commit c1dfd90

Browse files
committed
Reducing size of the file
1 parent 51eb6b2 commit c1dfd90

File tree

7 files changed

+1213
-1493
lines changed

7 files changed

+1213
-1493
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
exclude_types: [svg]
2323
- id: check-yaml
2424
- id: check-added-large-files
25-
exclude: &exclude_pattern 'iv_weak_instruments.ipynb'
25+
exclude: &exclude_pattern '(iv_weak_instruments\.ipynb|cp_covid\.ipynb)'
2626
args: ["--maxkb=1500"]
2727
- repo: https://github.com/astral-sh/ruff-pre-commit
2828
rev: v0.14.0

0.14.0

Whitespace-only changes.

causalpy/experiments/change_point_detection.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,11 @@
8686
--------------
8787
Known treatment time (traditional approach):
8888
89-
>>> result = cp.InterruptedTimeSeries(
89+
>>> result = cp.ChangePointDetection(
9090
... data=df,
91-
... treatment_time=pd.to_datetime("2017-01-01"), # Known intervention
91+
...     time_range=None,
9292
... formula="y ~ 1 + t + C(month)",
93-
... model=cp.pymc_models.LinearRegression(),
94-
... )
95-
96-
Unknown treatment time (inference approach):
97-
98-
>>> model = cp.pymc_models.InterventionTimeEstimator(treatment_effect_type="level")
99-
>>> result = cp.InterruptedTimeSeries(
100-
... data=df,
101-
... treatment_time=None, # Let model infer the time
102-
... formula="y ~ 1 + t + C(month)",
103-
... model=model,
93+
... model=cp.pymc_models.LinearChangePointDetection(),
10494
... )
10595
10696
The module automatically selects the appropriate handler based on the treatment_time
@@ -164,35 +154,35 @@ class ChangePointDetection(BaseExperiment):
164154
... )
165155
"""
166156

167-
expt_type = "Interrupted Time Series"
157+
expt_type = "Change Point Detection"
168158
supports_ols = False
169159
supports_bayes = True
170160

171161
def __init__(
172162
self,
173163
data: pd.DataFrame,
174164
formula: str,
175-
treatment_time_range: Union[Iterable, None] = None,
165+
time_range: Union[Iterable, None] = None,
176166
model=None,
177167
**kwargs,
178168
) -> None:
179169
super().__init__(model=model)
180170

181171
# rename the index to "obs_ind"
182172
data.index.name = "obs_ind"
183-
self.input_validation(data, treatment_time_range, model)
173+
self.input_validation(data, time_range, model)
184174

185175
# set experiment type - usually done in subclasses
186176
self.expt_type = "Pre-Post Fit"
187177

188-
self.treatment_time_range = treatment_time_range
178+
self.time_range = time_range
189179
self.formula = formula
190180

191181
# Define the time interval over which the model will perform inference
192-
model.set_time_range(self.treatment_time_range, data)
182+
model.set_time_range(self.time_range, data)
193183

194184
# Preprocess the data according to the given formula
195-
y, X = dmatrices(formula, self.datapre)
185+
y, X = dmatrices(formula, data)
196186

197187
self.outcome_variable_name = y.design_info.column_names[0]
198188
self._y_design_info = y.design_info
@@ -205,14 +195,14 @@ def __init__(
205195
self.X,
206196
dims=["obs_ind", "coeffs"],
207197
coords={
208-
"obs_ind": self.datapre.index,
198+
"obs_ind": data.index,
209199
"coeffs": self.labels,
210200
},
211201
)
212202
self.y = xr.DataArray(
213203
self.y, # Keep 2D shape
214204
dims=["obs_ind", "treated_units"],
215-
coords={"obs_ind": self.datapre.index, "treated_units": ["unit_0"]},
205+
coords={"obs_ind": data.index, "treated_units": ["unit_0"]},
216206
)
217207

218208
# fit the model to the observed data
@@ -266,34 +256,40 @@ def __init__(
266256
timeline_broadcast = np.array(timeline)
267257
tt_broadcast = cp_samples[:, :, None].astype(int)
268258
mask = (timeline_broadcast >= tt_broadcast).astype(int)
259+
mask = mask[:, :, np.newaxis, :]
260+
post_impact_masked = impact * mask
269261

270-
# --- Compute cumulative post-treatment impact ---
262+
# --- Compute cumulative post-change point impact ---
271263
post_impact_masked = impact * mask
272264
self.post_impact_cumulative = model.calculate_cumulative_impact(
273265
post_impact_masked
274266
)
275267

276-
def input_validation(self, data, treatment_time_range, model):
268+
def input_validation(self, data, time_range, model):
277269
"""Validate the input data and model formula for correctness"""
278270
if not hasattr(model, "set_time_range"):
279271
raise ModelException("Provided model must have a 'set_time_range' method")
280-
if treatment_time_range is not None and len(treatment_time_range) != 2:
272+
if time_range is not None and len(time_range) != 2:
281273
raise BadIndexException(
282-
"Provided treatment_time_range must be of length 2 : (start, end)"
274+
"Provided time_range must be of length 2 : (start, end)"
283275
)
284276
if isinstance(data.index, pd.DatetimeIndex) and not (
285-
treatment_time_range is None
286-
or isinstance(treatment_time_range, Iterable[pd.Timestamp])
277+
time_range is None
278+
or (
279+
isinstance(time_range, Iterable)
280+
and all(isinstance(t, pd.Timestamp) for t in time_range)
281+
)
287282
):
288283
raise BadIndexException(
289-
"If data.index is DatetimeIndex, treatment_time_range must "
284+
"If data.index is DatetimeIndex, time_range must "
290285
"be of type Iterable[pd.Timestamp]."
291286
)
292-
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
293-
treatment_time_range, Iterable[pd.Timestamp]
287+
if not isinstance(data.index, pd.DatetimeIndex) and (
288+
isinstance(time_range, Iterable)
289+
and all(isinstance(t, pd.Timestamp) for t in time_range)
294290
):
295291
raise BadIndexException(
296-
"If data.index is not DatetimeIndex, treatment_time_range must"
292+
"If data.index is not DatetimeIndex, time_range must"
297293
"not be of type Iterable[pd.Timestamp]." # noqa: E501
298294
)
299295

@@ -324,7 +320,7 @@ def _bayesian_plot(
324320
labels = []
325321

326322
# Treated counterfactual
327-
# Plot predicted values under treatment (with HDI)
323+
# Plot predicted values after change point (with HDI)
328324
h_line, h_patch = plot_xY(
329325
self.datapre.index,
330326
self.pre_pred["posterior_predictive"].mu_ts.isel(treated_units=0),
@@ -440,17 +436,17 @@ def _bayesian_plot(
440436
)
441437
ax[2].axhline(y=0, c="k")
442438

443-
# Plot vertical line marking treatment time (with HDI if it's inferred)
439+
# Plot vertical line marking change point (with HDI if it's inferred)
444440
data = pd.concat([self.datapre, self.datapost])
445-
# Extract the HDI (uncertainty interval) of the treatment time
446-
hdi = az.hdi(self.idata, var_names=["treatment_time"])["treatment_time"].values
441+
# Extract the HDI (uncertainty interval) of the change point
442+
hdi = az.hdi(self.idata, var_names=["change_point"])["change_point"].values
447443
x1 = data.index[int(hdi[0])]
448444
x2 = data.index[int(hdi[1])]
449445

450446
for i in [0, 1, 2]:
451447
ymin, ymax = ax[i].get_ylim()
452448

453-
# Vertical line for inferred treatment time
449+
# Vertical line for inferred change point
454450
ax[i].plot(
455451
[self.changepoint, self.changepoint],
456452
[ymin, ymax],
@@ -460,7 +456,7 @@ def _bayesian_plot(
460456
solid_capstyle="butt",
461457
)
462458

463-
# Shaded region for HDI of treatment time
459+
# Shaded region for HDI of change point
464460
ax[i].fill_betweenx(
465461
y=[ymin, ymax],
466462
x1=x1,
@@ -545,13 +541,13 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
545541
else:
546542
raise ValueError("Unsupported model type")
547543

548-
def plot_treatment_time(self):
544+
def plot_change_point(self):
549545
"""
550-
display the posterior estimates of the treatment time
546+
display the posterior estimates of the change point
551547
"""
552-
if "treatment_time" not in self.idata.posterior.data_vars:
548+
if "change_point" not in self.idata.posterior.data_vars:
553549
raise ValueError(
554-
"Variable 'treatment_time' not found in inference data (idata)."
550+
"Variable 'change_point' not found in inference data (idata)."
555551
)
556552

557-
az.plot_trace(self.idata, var_names="treatment_time")
553+
az.plot_trace(self.idata, var_names="change_point")

causalpy/pymc_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ class LinearChangePointDetection(PyMCModel):
931931
>>> import causalpy as cp
932932
>>> import numpy as np
933933
>>> from patsy import build_design_matrices, dmatrices
934-
>>> from causalpy.pymc_models import InterventionTimeEstimator as ITE
934+
>>> from causalpy.pymc_models import LinearChangePointDetection
935935
>>> data = cp.load_data("its")
936936
>>> formula="y ~ 1 + t + C(month)"
937937
>>> y, X = dmatrices(formula, data)
@@ -959,7 +959,7 @@ class LinearChangePointDetection(PyMCModel):
959959
... "obs_ind": np.arange(X.shape[0]),
960960
... "treated_units": ["unit_0"],
961961
... }
962-
>>> model = ITE(treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
962+
>>> model = LinearChangePointDetection(cp_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
963963
>>> model.set_time_range(None, data)
964964
>>> model.fit(X=_X, y=_y, coords=COORDS)
965965
Inference ...
@@ -974,7 +974,7 @@ def __init__(
974974
"""
975975
Initializes the LinearChangePointDetection model.
976976
977-
:param treatment_effect_type: Optional dictionary that specifies prior parameters for the
977+
:param cp_effect_type: Optional dictionary that specifies prior parameters for the
978978
intervention effects. Expected keys are:
979979
- "level": [mu, sigma]
980980
- "trend": [mu, sigma]
@@ -1039,7 +1039,7 @@ def build_model(self, X, y, coords):
10391039
Assumes the following attributes are already defined in self:
10401040
- self.timeline: the index of the column in X representing time.
10411041
- self.time_range: a tuple (lower_bound, upper_bound) for the intervention time.
1042-
- self.treatment_effect_type: a dictionary specifying which intervention effects to include and their priors.
1042+
- self.cp_effect_type: a dictionary specifying which intervention effects to include and their priors.
10431043
"""
10441044

10451045
with self:

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,13 @@ def test_its(mock_pymc_sample):
404404

405405

406406
@pytest.mark.integration
407-
def test_its_no_treatment_time():
407+
def test_cp_covid():
408408
"""
409-
Test Interrupted Time-Series experiment on COVID data with an unknown treatment time.
409+
Test ChangePointDetection experiment on COVID data.
410410
411411
Loads data and checks:
412412
1. data is a dataframe
413-
2. causalpy.InterruptedtimeSeries returns correct type
413+
2. causalpy.ChangePointDetection returns correct type
414414
3. the correct number of MCMC chains exists in the posterior inference data
415415
4. the correct number of MCMC draws exists in the posterior inference data
416416
5. the method get_plot_data returns a DataFrame with expected columns
@@ -421,26 +421,23 @@ def test_its_no_treatment_time():
421421
.assign(date=lambda x: pd.to_datetime(x["date"]))
422422
.set_index("date")
423423
)
424-
treatment_time = (pd.to_datetime("2014-01-01"), pd.to_datetime("2022-01-01"))
424+
time_range = (pd.to_datetime("2014-01-01"), pd.to_datetime("2022-01-01"))
425425

426426
# Assert that we correctly raise a ModelException if the given model can't predict InterventionTime
427427
with pytest.raises(cp.custom_exceptions.ModelException) as exc_info:
428-
cp.InterruptedTimeSeries(
428+
cp.ChangePointDetection(
429429
df,
430-
treatment_time,
430+
time_range=time_range,
431431
formula="standardize(deaths) ~ 0 + t + C(month) + standardize(temp)", # noqa E501
432432
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
433433
)
434-
assert (
435-
"If treatment_time is a tuple, provided model must have a 'set_time_range' method"
436-
in str(exc_info.value)
437-
)
434+
assert "Provided model must have a 'set_time_range' method" in str(exc_info.value)
438435

439-
result = cp.InterruptedTimeSeries(
436+
result = cp.ChangePointDetection(
440437
df,
441-
treatment_time,
438+
time_range=time_range,
442439
formula="standardize(deaths) ~ 0 + t + C(month) + standardize(temp)", # noqa E501
443-
model=cp.pymc_models.InterventionTimeEstimator(
440+
model=cp.pymc_models.LinearChangePointDetection(
444441
treatment_effect_type=["impulse", "level", "trend"],
445442
sample_kwargs=sample_kwargs,
446443
),

docs/source/notebooks/cp_covid.ipynb

Lines changed: 1160 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/its_no_treatment_time.ipynb

Lines changed: 0 additions & 1433 deletions
This file was deleted.

0 commit comments

Comments
 (0)