Skip to content

Commit 07df35f

Browse files
author
juanitorduz
committed
fix code style
1 parent f1c747f commit 07df35f

File tree

6 files changed

+89
-41
lines changed

6 files changed

+89
-41
lines changed

causalpy/data/simulate_data.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def generate_synthetic_control_data(
2828
"""
2929
Example:
3030
>> import pathlib
31-
>> df, weightings_true = generate_synthetic_control_data(treatment_time=treatment_time)
31+
>> df, weightings_true = generate_synthetic_control_data(
32+
treatment_time=treatment_time
33+
)
3234
>> df.to_csv(pathlib.Path.cwd() / 'synthetic_control.csv', index=False)
3335
"""
3436

@@ -45,15 +47,17 @@ def generate_synthetic_control_data(
4547
}
4648
)
4749

48-
# 2. Generate counterfactual, based on weighted sum of non-treated variables. This is the counterfactual with NO treatment.
50+
# 2. Generate counterfactual, based on weighted sum of non-treated variables. This
51+
# is the counterfactual with NO treatment.
4952
weightings_true = dirichlet(np.ones(7)).rvs(1)
5053
df["counterfactual"] = np.dot(df.to_numpy(), weightings_true.T)
5154

5255
# 3. Generate the causal effect
5356
causal_effect = gamma(10).pdf(np.arange(0, N, 1) - treatment_time)
5457
df["causal effect"] = causal_effect * -50
5558

56-
# 4. Generate the actually observed data, ie the treated with the causal effect applied
59+
# 4. Generate the actually observed data, i.e. the treated with the causal effect
60+
# applied
5761
df["actual"] = df["counterfactual"] + df["causal effect"]
5862

5963
# 5. apply observation noise to all relevant variables
@@ -126,7 +130,9 @@ def generate_time_series_data(treatment_time):
126130

127131

128132
def generate_time_series_data_simple(treatment_time, slope=0.0):
129-
"""Generate simple interrupted time series data, with no seasonality or temporal structure"""
133+
"""Generate simple interrupted time series data, with no seasonality or temporal
134+
structure.
135+
"""
130136
dates = pd.date_range(
131137
start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
132138
)

causalpy/pymc_experiments.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,28 @@ def print_coefficients(self):
2828
"""Prints the model coefficients"""
2929
print("Model coefficients:")
3030
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
31-
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
31+
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of
32+
# the stats despite variable names of different lengths
3233
for name in self.labels:
3334
coeff_samples = coeffs.sel(coeffs=name)
3435
print(
35-
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
36+
f"""
37+
{name: <30}{coeff_samples.mean().data:.2f},
38+
94% HDI [{coeff_samples.quantile(0.03).data:.2f},
39+
{coeff_samples.quantile(1-0.03).data:.2f}]
40+
"""
3641
)
3742
# add coeff for measurement std
3843
coeff_samples = az.extract(
3944
self.prediction_model.idata.posterior, var_names="sigma"
4045
)
4146
name = "sigma"
4247
print(
43-
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
48+
f"""
49+
{name: <30}{coeff_samples.mean().data:.2f},
50+
94% HDI [{coeff_samples.quantile(0.03).data:.2f},
51+
{coeff_samples.quantile(1-0.03).data:.2f}]
52+
"""
4453
)
4554

4655

@@ -120,8 +129,12 @@ def plot(self):
120129
include_label=False,
121130
)
122131
ax[0].plot(self.datapost.index, self.post_y, "k.")
132+
123133
ax[0].set(
124-
title=f"Pre-intervention Bayesian $R^2$: {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
134+
title=f"""
135+
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
136+
(std = {self.score.r2_std:.3f})
137+
"""
125138
)
126139

127140
plot_xY(self.datapre.index, self.pre_impact, ax=ax[1])
@@ -200,7 +213,8 @@ class DifferenceInDifferences(ExperimentalDesign):
200213
201214
.. note::
202215
203-
There is no pre/post intervention data distinction for DiD, we fit all the data available.
216+
There is no pre/post intervention data distinction for DiD, we fit all the
217+
data available.
204218
205219
"""
206220

@@ -224,7 +238,8 @@ def __init__(
224238
self.y, self.X = np.asarray(y), np.asarray(X)
225239
self.outcome_variable_name = y.design_info.column_names[0]
226240

227-
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
241+
# TODO: `treated` is a deterministic function of group and time, so this should
242+
# be a function rather than supplied data
228243

229244
# DEVIATION FROM SKL EXPERIMENT CODE =============================
230245
# fit the model to the observed (pre-intervention) data
@@ -309,7 +324,8 @@ def plot(self):
309324
showmedians=False,
310325
widths=0.2,
311326
)
312-
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
327+
# Plot counterfactual - post-test for treatment group IF no treatment
328+
# had occurred.
313329
parts = ax.violinplot(
314330
az.extract(
315331
self.y_pred_counterfactual,
@@ -381,13 +397,16 @@ class RegressionDiscontinuity(ExperimentalDesign):
381397
382398
:param data: A pandas dataframe
383399
:param formula: A statistical model formula
384-
:param treatment_threshold: A scalar threshold value at which the treatment is applied
400+
:param treatment_threshold: A scalar threshold value at which the treatment
401+
is applied
385402
:param prediction_model: A PyMC model
386-
:param running_variable_name: The name of the predictor variable that the treatment threshold is based upon
403+
:param running_variable_name: The name of the predictor variable that the treatment
404+
threshold is based upon
387405
388406
.. note::
389407
390-
There is no pre/post intervention data distinction for the regression discontinuity design, we fit all the data available.
408+
There is no pre/post intervention data distinction for the regression
409+
discontinuity design, we fit all the data available.
391410
"""
392411

393412
def __init__(
@@ -412,7 +431,8 @@ def __init__(
412431
self.y, self.X = np.asarray(y), np.asarray(X)
413432
self.outcome_variable_name = y.design_info.column_names[0]
414433

415-
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
434+
# TODO: `treated` is a deterministic function of x and treatment_threshold, so
435+
# this could be a function rather than supplied data
416436

417437
# DEVIATION FROM SKL EXPERIMENT CODE =============================
418438
# fit the model to the observed (pre-intervention) data
@@ -435,8 +455,10 @@ def __init__(
435455
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
436456
self.pred = self.prediction_model.predict(X=np.asarray(new_x))
437457

438-
# calculate discontinuity by evaluating the difference in model expectation on either side of the discontinuity
439-
# NOTE: `"treated": np.array([0, 1])`` assumes treatment is applied above (not below) the threshold
458+
# calculate discontinuity by evaluating the difference in model expectation on
459+
# either side of the discontinuity
460+
# NOTE: `"treated": np.array([0, 1])` assumes treatment is applied above
461+
# (not below) the threshold
440462
self.x_discon = pd.DataFrame(
441463
{
442464
self.running_variable_name: np.array(
@@ -457,7 +479,7 @@ def _is_treated(self, x):
457479
458480
.. warning::
459481
460-
Assumes treatment is given to those ABOVE the treatment threshold.
482+
Assumes treatment is given to those ABOVE the treatment threshold.
461483
"""
462484
return np.greater_equal(x, self.treatment_threshold)
463485

@@ -479,10 +501,13 @@ def plot(self):
479501
ax=ax,
480502
)
481503
# create strings to compose title
482-
r2 = f"Bayesian $R^2$ on all data = {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
504+
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
505+
r2 = f"Bayesian $R^2$ on all data = {title_info}"
483506
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
484507
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
485-
discon = f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}, "
508+
discon = f"""
509+
Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f},
510+
"""
486511
ax.set(title=r2 + "\n" + discon + ci)
487512
# Intervention line
488513
ax.axvline(
@@ -502,7 +527,7 @@ def summary(self):
502527
print(f"Formula: {self.formula}")
503528
print(f"Running variable: {self.running_variable_name}")
504529
print(f"Threshold on running variable: {self.treatment_threshold}")
505-
print(f"\nResults:")
530+
print("\nResults:")
506531
print(
507532
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
508533
)

causalpy/pymc_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def _data_setter(self, X):
2121
pm.set_data({"X": X})
2222

2323
def fit(self, X, y, coords):
24-
"""Draw samples from posterior, prior predictive, and posterior predictive distributions."""
24+
"""Draw samples from posterior, prior predictive, and posterior predictive
25+
distributions.
26+
"""
2527
self.build_model(X, y, coords)
2628
with self.model:
2729
self.idata = pm.sample()
@@ -43,7 +45,8 @@ def score(self, X, y):
4345
4446
.. caution::
4547
46-
The Bayesian :math:`R^2` is not the same as the traditional coefficient of determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
48+
The Bayesian :math:`R^2` is not the same as the traditional coefficient of
49+
determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
4750
4851
"""
4952
yhat = self.predict(X)

causalpy/skl_experiments.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ def plot_coeffs(self):
141141
)
142142

143143

144-
# InterruptedTimeSeries and SyntheticControl are basically the same thing but with different
145-
# predictor variables. So we just have a TimeSeriesExperiment class and InterruptedTimeSeries
146-
# and SyntheticControl are both equal to the TimeSeriesExperiment class
144+
# InterruptedTimeSeries and SyntheticControl are basically the same thing but with
145+
# different predictor variables. So we just have a TimeSeriesExperiment class and
146+
# InterruptedTimeSeries and SyntheticControl are both equal to the TimeSeriesExperiment
147+
# class
147148

148149

149150
class InterruptedTimeSeries(TimeSeriesExperiment):
@@ -168,7 +169,8 @@ class DifferenceInDifferences(ExperimentalDesign):
168169
"""
169170
.. note::
170171
171-
There is no pre/post intervention data distinction for DiD, we fit all the data available.
172+
There is no pre/post intervention data distinction for DiD, we fit all the data
173+
available.
172174
"""
173175

174176
def __init__(
@@ -190,7 +192,8 @@ def __init__(
190192
self.y, self.X = np.asarray(y), np.asarray(X)
191193
self.outcome_variable_name = y.design_info.column_names[0]
192194

193-
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
195+
# TODO: `treated` is a deterministic function of group and time, so this should
196+
# be a function rather than supplied data
194197

195198
# fit the model to all the data
196199
self.prediction_model.fit(X=self.X, y=self.y)
@@ -254,7 +257,8 @@ def plot(self):
254257
markersize=10,
255258
label="model fit (treament group)",
256259
)
257-
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
260+
# Plot counterfactual - post-test for treatment group IF no treatment
261+
# had occurred.
258262
ax.plot(
259263
self.x_pred_counterfactual[self.time_variable_name],
260264
self.y_pred_counterfactual,
@@ -297,7 +301,8 @@ class RegressionDiscontinuity(ExperimentalDesign):
297301
298302
.. note::
299303
300-
There is no pre/post intervention data distinction for the regression discontinuity design, we fit all the data available.
304+
There is no pre/post intervention data distinction for the regression
305+
discontinuity design, we fit all the data available.
301306
302307
"""
303308

@@ -322,7 +327,8 @@ def __init__(
322327
self.y, self.X = np.asarray(y), np.asarray(X)
323328
self.outcome_variable_name = y.design_info.column_names[0]
324329

325-
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
330+
# TODO: `treated` is a deterministic function of x and treatment_threshold, so
331+
# this could be a function rather than supplied data
326332

327333
# fit the model to all the data
328334
self.prediction_model.fit(X=self.X, y=self.y)
@@ -342,8 +348,10 @@ def __init__(
342348
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
343349
self.pred = self.prediction_model.predict(X=np.asarray(new_x))
344350

345-
# calculate discontinuity by evaluating the difference in model expectation on either side of the discontinuity
346-
# NOTE: `"treated": np.array([0, 1])`` assumes treatment is applied above (not below) the threshold
351+
# calculate discontinuity by evaluating the difference in model expectation on
352+
# either side of the discontinuity
353+
# NOTE: `"treated": np.array([0, 1])` assumes treatment is applied above
354+
# (not below) the threshold
347355
self.x_discon = pd.DataFrame(
348356
{
349357
self.running_variable_name: np.array(
@@ -359,11 +367,12 @@ def __init__(
359367
)
360368

361369
def _is_treated(self, x):
362-
"""Returns ``True`` if ``x`` is greater than or equal to the treatment threshold.
370+
"""Returns ``True`` if ``x`` is greater than or equal to the treatment
371+
threshold.
363372
364373
.. warning::
365374
366-
Assumes treatment is given to those ABOVE the treatment threshold.
375+
Assumes treatment is given to those ABOVE the treatment threshold.
367376
"""
368377
return np.greater_equal(x, self.treatment_threshold)
369378

@@ -406,7 +415,7 @@ def summary(self):
406415
print(f"Formula: {self.formula}")
407416
print(f"Running variable: {self.running_variable_name}")
408417
print(f"Threshold on running variable: {self.treatment_threshold}")
409-
print(f"\nResults:")
418+
print("\nResults:")
410419
print(f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}")
411420
print("Model coefficients:")
412421
for name, val in zip(self.labels, self.prediction_model.coef_[0]):

docs/conf.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import os
1313
import sys
1414

15+
from causalpy.version import __version__
16+
1517
sys.path.insert(0, os.path.abspath("../"))
1618

1719
# autodoc_mock_imports
1820
# This avoids autodoc breaking when it can't find packages imported in the code.
19-
# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports
21+
# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports # noqa: E501
2022
autodoc_mock_imports = [
2123
"arviz",
2224
"matplotlib",
@@ -37,7 +39,6 @@
3739
copyright = "2022, Benjamin T. Vincent"
3840
author = "Benjamin T. Vincent"
3941

40-
from causalpy.version import __version__
4142

4243
release = __version__
4344

@@ -57,10 +58,13 @@
5758
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
5859

5960
# -- nbsphinx config ----------------------------------------------------------
60-
# Opt out of executing the notebooks remotely. This will save time in the remote build process on readthedocs. The notebooks in /docs/notebooks will be parsed/converted, but not re-executed.
61+
# Opt out of executing the notebooks remotely. This will save time in the remote build
62+
# process on readthedocs. The notebooks in /docs/notebooks will be parsed/converted,
63+
# but not re-executed.
6164
nbsphinx_execute = "never"
6265

63-
# MyST options for working with markdown files. Info about extensions here https://myst-parser.readthedocs.io/en/latest/syntax/optional.html?highlight=math#admonition-directives
66+
# MyST options for working with markdown files.
67+
# Info about extensions here https://myst-parser.readthedocs.io/en/latest/syntax/optional.html?highlight=math#admonition-directives # noqa: E501
6468
myst_enable_extensions = ["dollarmath", "amsmath", "colon_fence", "linkify"]
6569

6670
# -- Options for HTML output -------------------------------------------------

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from setuptools import find_packages, setup
55

6+
from causalpy.version import __version__
7+
68
PROJECT_ROOT = os.path.dirname(os.path.realpath(__file__))
79
README_FILE = os.path.join(PROJECT_ROOT, "README.md")
810
VERSION_FILE = os.path.join(PROJECT_ROOT, "bambi", "version.py")
@@ -16,7 +18,6 @@ def get_long_description():
1618

1719
# get version
1820
sys.path.insert(0, os.path.abspath("../"))
19-
from causalpy.version import __version__
2021

2122
with open(REQUIREMENTS_FILE) as f:
2223
install_reqs = f.read().splitlines()

0 commit comments

Comments
 (0)