Skip to content

Commit a92e9d4

Browse files
committed
gemini
1 parent b972ae2 commit a92e9d4

File tree

6 files changed

+18
-5
lines changed

6 files changed

+18
-5
lines changed

causalpy/data/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@
4343
}
4444

4545

46-
def _get_data_home() -> pathlib.PosixPath:
46+
def _get_data_home() -> pathlib.Path:
4747
"""Return the path of the data directory"""
4848
return pathlib.Path(cp.__file__).parents[1] / "causalpy" / "data"
4949

5050

51-
def load_data(dataset: str = None) -> pd.DataFrame:
51+
def load_data(dataset: str | None = None) -> pd.DataFrame:
5252
"""Loads the requested dataset and returns a pandas DataFrame.
5353
5454
:param dataset: The desired dataset to load

causalpy/experiments/diff_in_diff.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
**kwargs,
9797
) -> None:
9898
super().__init__(model=model)
99+
self.causal_impact: xr.DataArray | float | None
99100
# rename the index to "obs_ind"
100101
data.index.name = "obs_ind"
101102
self.data = data
@@ -213,6 +214,7 @@ def __init__(
213214

214215
# calculate causal impact
215216
if isinstance(self.model, PyMCModel):
217+
assert self.model.idata is not None
216218
# This is the coefficient on the interaction term
217219
coeff_names = self.model.idata.posterior.coords["coeffs"].data
218220
for i, label in enumerate(coeff_names):
@@ -395,7 +397,7 @@ def _plot_causal_impact_arrow(results, ax):
395397
labels = ["Control group"]
396398

397399
# Plot model fit to treatment group
398-
time_points = self.x_pred_control[self.time_variable_name].values
400+
time_points = self.x_pred_treatment[self.time_variable_name].values
399401
h_line, h_patch = plot_xY(
400402
time_points,
401403
self.y_pred_treatment["posterior_predictive"].mu.isel(treated_units=0),

causalpy/experiments/interrupted_time_series.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def __init__(
9393
**kwargs,
9494
) -> None:
9595
super().__init__(model=model)
96+
self.pre_y: xr.DataArray
97+
self.post_y: xr.DataArray
9698
# rename the index to "obs_ind"
9799
data.index.name = "obs_ind"
98100
self.input_validation(data, treatment_time)

causalpy/experiments/prepostnegd.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def __init__(
9898
**kwargs,
9999
):
100100
super().__init__(model=model)
101+
self.causal_impact: xr.DataArray
102+
self.pred_xi: np.ndarray
103+
self.pred_untreated: az.InferenceData
104+
self.pred_treated: az.InferenceData
101105
self.data = data
102106
self.expt_type = "Pretest/posttest Nonequivalent Group Design"
103107
self.formula = formula
@@ -140,6 +144,7 @@ def __init__(
140144
else:
141145
raise ValueError("Model type not recognized")
142146

147+
assert self.model.idata is not None
143148
# Calculate the posterior predictive for the treatment and control for an
144149
# interpolated set of pretest values
145150
# get the model predictions of the observed data

causalpy/plot_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
import xarray as xr
2525
from matplotlib.collections import PolyCollection
2626
from matplotlib.lines import Line2D
27+
from pandas.api.extensions import ExtensionArray
2728

2829

2930
def plot_xY(
30-
x: Union[pd.DatetimeIndex, np.array],
31+
x: Union[pd.DatetimeIndex, np.ndarray, pd.Index, pd.Series, ExtensionArray],
3132
Y: xr.DataArray,
3233
ax: plt.Axes,
3334
plot_hdi_kwargs: Optional[Dict[str, Any]] = None,

causalpy/pymc_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class PyMCModel(pm.Model):
9191
Inference data...
9292
"""
9393

94-
default_priors = {}
94+
default_priors: Dict[str, Prior] = {}
9595

9696
def priors_from_data(self, X, y) -> Dict[str, Any]:
9797
"""
@@ -236,6 +236,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
236236
self.build_model(X, y, coords)
237237
with self:
238238
self.idata = pm.sample(**self.sample_kwargs)
239+
assert self.idata is not None
239240
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
240241
self.idata.extend(
241242
pm.sample_posterior_predictive(
@@ -349,6 +350,8 @@ def calculate_cumulative_impact(self, impact):
349350
return impact.cumsum(dim="obs_ind")
350351

351352
def print_coefficients(self, labels, round_to=None) -> None:
353+
assert self.idata is not None
354+
352355
def print_row(
353356
max_label_length: int, name: str, coeff_samples: xr.DataArray, round_to: int
354357
) -> None:

0 commit comments

Comments
 (0)