Skip to content

Commit 36d6e1d

Browse files
committed
#53 adding more docstring info
1 parent 032d316 commit 36d6e1d

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

causalpy/pymc_experiments.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@ def __init__(self, prediction_model=None, **kwargs):
2626
class TimeSeriesExperiment(ExperimentalDesign):
2727
"""A class to analyse time series quasi-experiments"""
2828

29-
def __init__(self, data, treatment_time, formula, prediction_model=None, **kwargs):
29+
def __init__(
30+
self,
31+
data: pd.DataFrame,
32+
treatment_time: int,
33+
formula: str,
34+
prediction_model=None,
35+
**kwargs,
36+
) -> None:
3037
super().__init__(prediction_model=prediction_model, **kwargs)
3138
self.treatment_time = treatment_time
3239
# split data in to pre and post intervention
@@ -111,6 +118,8 @@ def plot(self):
111118

112119

113120
class SyntheticControl(TimeSeriesExperiment):
121+
"""A wrapper around the TimeSeriesExperiment class"""
122+
114123
def plot(self):
115124
"""Plot the results"""
116125
fig, ax = super().plot()
@@ -121,6 +130,8 @@ def plot(self):
121130

122131

123132
class InterruptedTimeSeries(TimeSeriesExperiment):
133+
"""A wrapper around the TimeSeriesExperiment class"""
134+
124135
pass
125136

126137

causalpy/pymc_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def predict(self, X):
3939
return post_pred
4040

4141
def score(self, X, y):
42-
"""Score the predictions $R^2$ given inputs X and outputs y."""
42+
"""Score the predictions :math:`R^2` given inputs ``X`` and outputs ``y``."""
4343
yhat = self.predict(X)
4444
yhat = az.extract(yhat, group="posterior_predictive", var_names="y_hat").mean(
4545
dim="sample"

causalpy/skl_experiments.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def __init__(
359359
)
360360

361361
def _is_treated(self, x):
362-
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
362+
"""Returns ``True`` if ``x`` is greater than or equal to the treatment threshold.
363363
364364
.. warning::
365365
@@ -401,6 +401,7 @@ def plot(self):
401401
return (fig, ax)
402402

403403
def summary(self):
404+
"""Print text output summarising the results"""
404405
print("Difference in Differences experiment")
405406
print(f"Formula: {self.formula}")
406407
print(f"Running variable: {self.running_variable_name}")

0 commit comments

Comments
 (0)