Skip to content

Commit 4115ef3

Browse files
author
juanitorduz
committed
fix code style
1 parent 56a2ddd commit 4115ef3

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

causalpy/data/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import pathlib
32

43
import pandas as pd

causalpy/pymc_experiments.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ def __init__(
244244
self.y, self.X = np.asarray(y), np.asarray(X)
245245
self.outcome_variable_name = y.design_info.column_names[0]
246246

247-
248247
# Input validation ----------------------------------------------------
249248
# Check that `treated` appears in the module formula
250249
assert (
@@ -254,17 +253,26 @@ def __init__(
254253
assert (
255254
"treated" in self.data.columns
256255
), "Require a boolean column labelling observations which are `treated`"
257-
# Check for `unit` in the incoming dataframe. *This is only used for plotting purposes*
256+
# Check for `unit` in the incoming dataframe.
257+
# *This is only used for plotting purposes*
258258
assert (
259259
"unit" in self.data.columns
260-
), "Require a `unit` column to label unique units. This is used for plotting purposes"
261-
# Check that `group_variable_name` has TWO levels, representing the treated/untreated.
262-
# But it does not matter what the actual names of the levels are.
260+
), """
261+
Require a `unit` column to label unique units.
262+
This is used for plotting purposes
263+
"""
264+
# Check that `group_variable_name` has TWO levels, representing the
265+
# treated/untreated. But it does not matter what the actual names of
266+
# the levels are.
263267
assert (
264-
len(pd.Categorical(self.data[self.group_variable_name]).categories) is 2
265-
), f"There must be 2 levels of the grouping variable {self.group_variable_name}.I.e. the treated and untreated."
268+
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
269+
), f"""
270+
There must be 2 levels of the grouping variable {self.group_variable_name}
271+
.I.e. the treated and untreated.
272+
"""
266273

267-
# TODO: `treated` is a deterministic function of group and time, so this could be a function rather than supplied data
274+
# TODO: `treated` is a deterministic function of group and time, so this could
275+
# be a function rather than supplied data
268276

269277
# DEVIATION FROM SKL EXPERIMENT CODE =============================
270278
# fit the model to the observed (pre-intervention) data
@@ -369,7 +377,8 @@ def plot(self):
369377
pc.set_facecolor("C1")
370378
pc.set_edgecolor("None")
371379
pc.set_alpha(0.5)
372-
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
380+
# Plot counterfactual - post-test for treatment group IF no treatment
381+
# had occurred.
373382
parts = ax.violinplot(
374383
az.extract(
375384
self.y_pred_counterfactual,
@@ -397,7 +406,8 @@ def plot(self):
397406

398407
def _plot_causal_impact_arrow(self, ax):
399408
"""
400-
draw a vertical arrow between `y_pred_counterfactual` and `y_pred_counterfactual`
409+
draw a vertical arrow between `y_pred_counterfactual` and
410+
`y_pred_counterfactual`
401411
"""
402412
# Calculate y values to plot the arrow between
403413
y_pred_treatment = (

0 commit comments

Comments
 (0)