Skip to content

Commit 9501440

Browse files
committed
#69 remove all instances of manually defining outcome_variable_name
1 parent 4cd357e commit 9501440

File tree

4 files changed

+14
-16
lines changed

4 files changed

+14
-16
lines changed

causalpy/pymc_experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646

4747
# set things up with pre-intervention data
4848
y, X = dmatrices(formula, self.datapre)
49+
self.outcome_variable_name = y.design_info.column_names[0]
4950
self._y_design_info = y.design_info
5051
self._x_design_info = X.design_info
5152
self.labels = X.design_info.column_names
@@ -179,20 +180,19 @@ def __init__(
179180
data,
180181
formula,
181182
time_variable_name="t",
182-
outcome_variable_name="y",
183183
prediction_model=None,
184184
**kwargs,
185185
):
186186
super().__init__(prediction_model=prediction_model, **kwargs)
187187
self.data = data
188188
self.formula = formula
189189
self.time_variable_name = time_variable_name
190-
self.outcome_variable_name = outcome_variable_name
191190
y, X = dmatrices(formula, self.data)
192191
self._y_design_info = y.design_info
193192
self._x_design_info = X.design_info
194193
self.labels = X.design_info.column_names
195194
self.y, self.X = np.asarray(y), np.asarray(X)
195+
self.outcome_variable_name = y.design_info.column_names[0]
196196

197197
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
198198

causalpy/skl_experiments.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class ExperimentalDesign:
1111
"""Base class for experiment designs"""
1212

1313
prediction_model = None
14+
outcome_variable_name = None
1415

1516
def __init__(self, prediction_model=None, **kwargs):
1617
if prediction_model is not None:
@@ -34,6 +35,7 @@ def __init__(self, data, treatment_time, formula, prediction_model=None, **kwarg
3435
self._y_design_info = y.design_info
3536
self._x_design_info = X.design_info
3637
self.labels = X.design_info.column_names
38+
self.outcome_variable_name = y.design_info.column_names[0]
3739
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
3840
# process post-intervention data
3941
(new_y, new_x) = build_design_matrices(
@@ -174,20 +176,19 @@ def __init__(
174176
data,
175177
formula,
176178
time_variable_name="t",
177-
outcome_variable_name="y",
178179
prediction_model=None,
179180
**kwargs,
180181
):
181182
super().__init__(prediction_model=prediction_model, **kwargs)
182183
self.data = data
183184
self.formula = formula
184185
self.time_variable_name = time_variable_name
185-
self.outcome_variable_name = outcome_variable_name
186186
y, X = dmatrices(formula, self.data)
187187
self._y_design_info = y.design_info
188188
self._x_design_info = X.design_info
189189
self.labels = X.design_info.column_names
190190
self.y, self.X = np.asarray(y), np.asarray(X)
191+
self.outcome_variable_name = y.design_info.column_names[0]
191192

192193
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
193194

@@ -307,20 +308,19 @@ def __init__(
307308
treatment_threshold,
308309
prediction_model=None,
309310
running_variable_name="x",
310-
outcome_variable_name="y",
311311
**kwargs,
312312
):
313313
super().__init__(prediction_model=prediction_model, **kwargs)
314314
self.data = data
315315
self.formula = formula
316316
self.running_variable_name = running_variable_name
317-
self.outcome_variable_name = outcome_variable_name
318317
self.treatment_threshold = treatment_threshold
319318
y, X = dmatrices(formula, self.data)
320319
self._y_design_info = y.design_info
321320
self._x_design_info = X.design_info
322321
self.labels = X.design_info.column_names
323322
self.y, self.X = np.asarray(y), np.asarray(X)
323+
self.outcome_variable_name = y.design_info.column_names[0]
324324

325325
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
326326

docs/notebooks/rd_pymc.ipynb

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

docs/notebooks/rd_skl_drinking.ipynb

Lines changed: 5 additions & 7 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)