Skip to content

Commit 988ac05

Browse files
authored
Merge pull request #79 from juanitorduz/basic_ci
Basic CI/CD and fix code style
2 parents 7e35f00 + 5a74cd8 commit 988ac05

16 files changed

+218
-82
lines changed

.github/workflows/ci.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: ci
2+
3+
on: [push]
4+
5+
jobs:
6+
lint:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
matrix:
10+
python-version: ["3.8", "3.9", "3.10"]
11+
12+
steps:
13+
- uses: actions/checkout@v3
14+
- name: Set up Python
15+
uses: actions/setup-python@v3
16+
with:
17+
python-version: ${{ matrix.python-version }}
18+
- name: Run lint
19+
run: |
20+
make init
21+
make check_lint
22+
test:
23+
runs-on: ubuntu-latest
24+
strategy:
25+
matrix:
26+
python-version: ["3.8", "3.9", "3.10"]
27+
28+
steps:
29+
- uses: actions/checkout@v3
30+
- name: Set up Python
31+
uses: actions/setup-python@v3
32+
with:
33+
python-version: ${{ matrix.python-version }}
34+
- name: Run tests
35+
run: |
36+
make init
37+
make test

.pre-commit-config.yaml

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
# See https://pre-commit.com for more information
22
# See https://pre-commit.com/hooks.html for more hooks
33
repos:
4-
- repo: https://github.com/pre-commit/pre-commit-hooks
5-
rev: v4.3.0
6-
hooks:
7-
- id: trailing-whitespace
8-
exclude_types: [svg]
9-
- id: end-of-file-fixer
10-
exclude_types: [svg]
11-
- id: check-yaml
12-
- id: check-added-large-files
13-
- repo: https://github.com/asottile/seed-isort-config
14-
rev: v2.2.0
15-
hooks:
16-
- id: seed-isort-config
17-
- repo: https://github.com/pre-commit/mirrors-isort
18-
rev: v5.10.1
19-
hooks:
20-
- id: isort
21-
types: [python]
22-
- repo: https://github.com/ambv/black
23-
rev: 22.10.0
24-
hooks:
25-
- id: black
26-
- id: black-jupyter
4+
- repo: https://github.com/pre-commit/pre-commit-hooks
5+
rev: v4.3.0
6+
hooks:
7+
- id: trailing-whitespace
8+
exclude_types: [svg]
9+
- id: end-of-file-fixer
10+
exclude_types: [svg]
11+
- id: check-yaml
12+
- id: check-added-large-files
13+
- repo: https://github.com/asottile/seed-isort-config
14+
rev: v2.2.0
15+
hooks:
16+
- id: seed-isort-config
17+
- repo: https://github.com/pre-commit/mirrors-isort
18+
rev: v5.10.1
19+
hooks:
20+
- id: isort
21+
args: [--profile, black]
22+
types: [python]
23+
- repo: https://github.com/ambv/black
24+
rev: 22.10.0
25+
hooks:
26+
- id: black
27+
- id: black-jupyter
28+
- repo: https://github.com/pycqa/flake8
29+
rev: 3.9.2
30+
hooks:
31+
- id: flake8

Makefile

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
.PHONY: init lint check_lint test
2+
3+
init:
4+
python -m pip install -e .
5+
6+
lint:
7+
pip install -r requirements-lint.txt
8+
isort .
9+
black .
10+
11+
check_lint:
12+
pip install -r requirements-lint.txt
13+
flake8 .
14+
isort --check-only .
15+
black --diff --check --fast .
16+
17+
test:
18+
pip install -r requirements-test.txt
19+
pytest

causalpy/data/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import pathlib
32

43
import pandas as pd

causalpy/data/simulate_data.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def generate_synthetic_control_data(
2828
"""
2929
Example:
3030
>> import pathlib
31-
>> df, weightings_true = generate_synthetic_control_data(treatment_time=treatment_time)
31+
>> df, weightings_true = generate_synthetic_control_data(
32+
treatment_time=treatment_time
33+
)
3234
>> df.to_csv(pathlib.Path.cwd() / 'synthetic_control.csv', index=False)
3335
"""
3436

@@ -45,15 +47,17 @@ def generate_synthetic_control_data(
4547
}
4648
)
4749

48-
# 2. Generate counterfactual, based on weighted sum of non-treated variables. This is the counterfactual with NO treatment.
50+
# 2. Generate counterfactual, based on weighted sum of non-treated variables. This
51+
# is the counterfactual with NO treatment.
4952
weightings_true = dirichlet(np.ones(7)).rvs(1)
5053
df["counterfactual"] = np.dot(df.to_numpy(), weightings_true.T)
5154

5255
# 3. Generate the causal effect
5356
causal_effect = gamma(10).pdf(np.arange(0, N, 1) - treatment_time)
5457
df["causal effect"] = causal_effect * -50
5558

56-
# 4. Generate the actually observed data, ie the treated with the causal effect applied
59+
# 4. Generate the actually observed data, i.e. the treated with the causal effect
60+
# applied
5761
df["actual"] = df["counterfactual"] + df["causal effect"]
5862

5963
# 5. apply observation noise to all relevant variables
@@ -94,13 +98,7 @@ def generate_time_series_data(
9498
return df
9599

96100

97-
def generate_time_series_data(treatment_time):
98-
"""
99-
Example use:
100-
>> import pathlib
101-
>> df = generate_time_series_data("2017-01-01").loc[:, ['month', 'year', 't', 'y']]
102-
df.to_csv(pathlib.Path.cwd() / 'its.csv')
103-
"""
101+
def generate_time_series_data_seasonal(treatment_time):
104102
dates = pd.date_range(
105103
start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
106104
)
@@ -126,7 +124,9 @@ def generate_time_series_data(treatment_time):
126124

127125

128126
def generate_time_series_data_simple(treatment_time, slope=0.0):
129-
"""Generate simple interrupted time series data, with no seasonality or temporal structure"""
127+
"""Generate simple interrupted time series data, with no seasonality or temporal
128+
structure.
129+
"""
130130
dates = pd.date_range(
131131
start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
132132
)

causalpy/pymc_experiments.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,28 @@ def print_coefficients(self):
2828
"""Prints the model coefficients"""
2929
print("Model coefficients:")
3030
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
31-
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
31+
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of
32+
# the stats despite variable names of different lengths
3233
for name in self.labels:
3334
coeff_samples = coeffs.sel(coeffs=name)
3435
print(
35-
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
36+
f"""
37+
{name: <30}{coeff_samples.mean().data:.2f},
38+
94% HDI [{coeff_samples.quantile(0.03).data:.2f},
39+
{coeff_samples.quantile(1-0.03).data:.2f}]
40+
"""
3641
)
3742
# add coeff for measurement std
3843
coeff_samples = az.extract(
3944
self.prediction_model.idata.posterior, var_names="sigma"
4045
)
4146
name = "sigma"
4247
print(
43-
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
48+
f"""
49+
{name: <30}{coeff_samples.mean().data:.2f},
50+
94% HDI [{coeff_samples.quantile(0.03).data:.2f},
51+
{coeff_samples.quantile(1-0.03).data:.2f}]
52+
"""
4453
)
4554

4655

@@ -121,8 +130,12 @@ def plot(self):
121130
include_label=False,
122131
)
123132
ax[0].plot(self.datapost.index, self.post_y, "k.")
133+
124134
ax[0].set(
125-
title=f"Pre-intervention Bayesian $R^2$: {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
135+
title=f"""
136+
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
137+
(std = {self.score.r2_std:.3f})
138+
"""
126139
)
127140

128141
plot_xY(self.datapre.index, self.pre_impact, ax=ax[1])
@@ -198,7 +211,8 @@ class DifferenceInDifferences(ExperimentalDesign):
198211
199212
.. note::
200213
201-
There is no pre/post intervention data distinction for DiD, we fit all the data available.
214+
There is no pre/post intervention data distinction for DiD, we fit all the
215+
data available.
202216
203217
"""
204218

@@ -239,16 +253,26 @@ def __init__(
239253
assert (
240254
"treated" in self.data.columns
241255
), "Require a boolean column labelling observations which are `treated`"
242-
# Check for `unit` in the incoming dataframe. *This is only used for plotting purposes*
256+
# Check for `unit` in the incoming dataframe.
257+
# *This is only used for plotting purposes*
243258
assert (
244259
"unit" in self.data.columns
245-
), "Require a `unit` column to label unique units. This is used for plotting purposes"
246-
# Check that `group_variable_name` has TWO levels, representing the treated/untreated. But it does not matter what the actual names of the levels are.
260+
), """
261+
Require a `unit` column to label unique units.
262+
This is used for plotting purposes
263+
"""
264+
# Check that `group_variable_name` has TWO levels, representing the
265+
# treated/untreated. But it does not matter what the actual names of
266+
# the levels are.
247267
assert (
248-
len(pd.Categorical(self.data[self.group_variable_name]).categories) is 2
249-
), f"There must be 2 levels of the grouping variable {self.group_variable_name}. I.e. the treated and untreated."
268+
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
269+
), f"""
270+
There must be 2 levels of the grouping variable {self.group_variable_name}
271+
.I.e. the treated and untreated.
272+
"""
250273

251-
# TODO: `treated` is a deterministic function of group and time, so this could be a function rather than supplied data
274+
# TODO: `treated` is a deterministic function of group and time, so this could
275+
# be a function rather than supplied data
252276

253277
# DEVIATION FROM SKL EXPERIMENT CODE =============================
254278
# fit the model to the observed (pre-intervention) data
@@ -348,11 +372,13 @@ def plot(self):
348372
showmedians=False,
349373
widths=0.2,
350374
)
375+
351376
for pc in parts["bodies"]:
352377
pc.set_facecolor("C1")
353378
pc.set_edgecolor("None")
354379
pc.set_alpha(0.5)
355-
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
380+
# Plot counterfactual - post-test for treatment group IF no treatment
381+
# had occurred.
356382
parts = ax.violinplot(
357383
az.extract(
358384
self.y_pred_counterfactual,
@@ -380,7 +406,8 @@ def plot(self):
380406

381407
def _plot_causal_impact_arrow(self, ax):
382408
"""
383-
draw a vertical arrow between `y_pred_counterfactual` and `y_pred_counterfactual`
409+
draw a vertical arrow between `y_pred_treatment` and
410+
`y_pred_counterfactual`
384411
"""
385412
# Calculate y values to plot the arrow between
386413
y_pred_treatment = (
@@ -438,13 +465,16 @@ class RegressionDiscontinuity(ExperimentalDesign):
438465
439466
:param data: A pandas dataframe
440467
:param formula: A statistical model formula
441-
:param treatment_threshold: A scalar threshold value at which the treatment is applied
468+
:param treatment_threshold: A scalar threshold value at which the treatment
469+
is applied
442470
:param prediction_model: A PyMC model
443-
:param running_variable_name: The name of the predictor variable that the treatment threshold is based upon
471+
:param running_variable_name: The name of the predictor variable that the treatment
472+
threshold is based upon
444473
445474
.. note::
446475
447-
There is no pre/post intervention data distinction for the regression discontinuity design, we fit all the data available.
476+
There is no pre/post intervention data distinction for the regression
477+
discontinuity design, we fit all the data available.
448478
"""
449479

450480
def __init__(
@@ -469,7 +499,8 @@ def __init__(
469499
self.y, self.X = np.asarray(y), np.asarray(X)
470500
self.outcome_variable_name = y.design_info.column_names[0]
471501

472-
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
502+
# TODO: `treated` is a deterministic function of x and treatment_threshold, so
503+
# this could be a function rather than supplied data
473504

474505
# DEVIATION FROM SKL EXPERIMENT CODE =============================
475506
# fit the model to the observed (pre-intervention) data
@@ -492,8 +523,10 @@ def __init__(
492523
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
493524
self.pred = self.prediction_model.predict(X=np.asarray(new_x))
494525

495-
# calculate discontinuity by evaluating the difference in model expectation on either side of the discontinuity
496-
# NOTE: `"treated": np.array([0, 1])`` assumes treatment is applied above (not below) the threshold
526+
# calculate discontinuity by evaluating the difference in model expectation on
527+
# either side of the discontinuity
528+
# NOTE: `"treated": np.array([0, 1])` assumes treatment is applied above
529+
# (not below) the threshold
497530
self.x_discon = pd.DataFrame(
498531
{
499532
self.running_variable_name: np.array(
@@ -514,7 +547,7 @@ def _is_treated(self, x):
514547
515548
.. warning::
516549
517-
Assumes treatment is given to those ABOVE the treatment threshold.
550+
Assumes treatment is given to those ABOVE the treatment threshold.
518551
"""
519552
return np.greater_equal(x, self.treatment_threshold)
520553

@@ -536,10 +569,13 @@ def plot(self):
536569
ax=ax,
537570
)
538571
# create strings to compose title
539-
r2 = f"Bayesian $R^2$ on all data = {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
572+
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
573+
r2 = f"Bayesian $R^2$ on all data = {title_info}"
540574
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
541575
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
542-
discon = f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}, "
576+
discon = f"""
577+
Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f},
578+
"""
543579
ax.set(title=r2 + "\n" + discon + ci)
544580
# Intervention line
545581
ax.axvline(
@@ -559,7 +595,7 @@ def summary(self):
559595
print(f"Formula: {self.formula}")
560596
print(f"Running variable: {self.running_variable_name}")
561597
print(f"Threshold on running variable: {self.treatment_threshold}")
562-
print(f"\nResults:")
598+
print("\nResults:")
563599
print(
564600
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
565601
)

causalpy/pymc_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def _data_setter(self, X):
2121
pm.set_data({"X": X})
2222

2323
def fit(self, X, y, coords):
24-
"""Draw samples from posterior, prior predictive, and posterior predictive distributions."""
24+
"""Draw samples from posterior, prior predictive, and posterior predictive
25+
distributions.
26+
"""
2527
self.build_model(X, y, coords)
2628
with self.model:
2729
self.idata = pm.sample()
@@ -43,7 +45,8 @@ def score(self, X, y):
4345
4446
.. caution::
4547
46-
The Bayesian :math:`R^2` is not the same as the traditional coefficient of determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
48+
The Bayesian :math:`R^2` is not the same as the traditional coefficient of
49+
determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
4750
4851
"""
4952
yhat = self.predict(X)

0 commit comments

Comments
 (0)