Skip to content

Commit cd565e7

Browse files
Change 'values' to 'value' in estimators.py
1 parent 237d80b commit cd565e7

File tree

3 files changed

+56
-58
lines changed

3 files changed

+56
-58
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
172172
causal_test_result = CausalTestResult(
173173
treatment=estimator.treatment,
174174
outcome=estimator.outcome,
175-
treatment_value=estimator.treatment_values,
176-
control_value=estimator.control_values,
175+
treatment_value=estimator.treatment_value,
176+
control_value=estimator.control_value,
177177
adjustment_set=estimator.adjustment_set,
178178
ate=cates_df,
179179
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
@@ -185,8 +185,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
185185
causal_test_result = CausalTestResult(
186186
treatment=estimator.treatment,
187187
outcome=estimator.outcome,
188-
treatment_value=estimator.treatment_values,
189-
control_value=estimator.control_values,
188+
treatment_value=estimator.treatment_value,
189+
control_value=estimator.control_value,
190190
adjustment_set=estimator.adjustment_set,
191191
ate=risk_ratio,
192192
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
@@ -198,8 +198,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
198198
causal_test_result = CausalTestResult(
199199
treatment=estimator.treatment,
200200
outcome=estimator.outcome,
201-
treatment_value=estimator.treatment_values,
202-
control_value=estimator.control_values,
201+
treatment_value=estimator.treatment_value,
202+
control_value=estimator.control_value,
203203
adjustment_set=estimator.adjustment_set,
204204
ate=ate,
205205
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
@@ -213,8 +213,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
213213
causal_test_result = CausalTestResult(
214214
treatment=estimator.treatment,
215215
outcome=estimator.outcome,
216-
treatment_value=estimator.treatment_values,
217-
control_value=estimator.control_values,
216+
treatment_value=estimator.treatment_value,
217+
control_value=estimator.control_value,
218218
adjustment_set=estimator.adjustment_set,
219219
ate=ate,
220220
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,

causal_testing/testing/estimators.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def __init__(
4242
effect_modifiers: dict[Variable:Any] = None,
4343
):
4444
self.treatment = treatment
45-
self.treatment_values = treatment_value
46-
self.control_values = control_value
45+
self.treatment_value = treatment_value
46+
self.control_value = control_value
4747
self.adjustment_set = adjustment_set
4848
self.outcome = outcome
4949
self.df = df
@@ -90,15 +90,15 @@ class LogisticRegressionEstimator(Estimator):
9090
def __init__(
9191
self,
9292
treatment: tuple,
93-
treatment_values: float,
94-
control_values: float,
93+
treatment_value: float,
94+
control_value: float,
9595
adjustment_set: set,
9696
outcome: tuple,
9797
df: pd.DataFrame = None,
9898
effect_modifiers: dict[Variable:Any] = None,
9999
intercept: int = 1,
100100
):
101-
super().__init__(treatment, treatment_values, control_values, adjustment_set, outcome, df, effect_modifiers)
101+
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
102102

103103
for term in self.effect_modifiers:
104104
self.adjustment_set.add(term)
@@ -155,7 +155,7 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
155155
self.model = model
156156

157157
x = pd.DataFrame()
158-
x[self.treatment[0]] = [self.treatment_values, self.control_values]
158+
x[self.treatment[0]] = [self.treatment_value, self.control_value]
159159
x["Intercept"] = self.intercept
160160
for k, v in self.effect_modifiers.items():
161161
x[k] = v
@@ -212,16 +212,16 @@ class LinearRegressionEstimator(Estimator):
212212
def __init__(
213213
self,
214214
treatment: tuple,
215-
treatment_values: float,
216-
control_values: float,
215+
treatment_value: float,
216+
control_value: float,
217217
adjustment_set: set,
218218
outcome: tuple,
219219
df: pd.DataFrame = None,
220220
effect_modifiers: dict[Variable:Any] = None,
221221
product_terms: list[tuple[Variable, Variable]] = None,
222222
intercept: int = 1,
223223
):
224-
super().__init__(treatment, treatment_values, control_values, adjustment_set, outcome, df, effect_modifiers)
224+
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
225225

226226
if product_terms is None:
227227
product_terms = []
@@ -304,7 +304,7 @@ def estimate_unit_ate(self) -> float:
304304
unit_effect = model.params[list(self.treatment)].values[0] # Unit effect is the coefficient of the treatment
305305
[ci_low, ci_high] = self._get_confidence_intervals(model)
306306

307-
return unit_effect * self.treatment_values - unit_effect * self.control_values, [ci_low, ci_high]
307+
return unit_effect * self.treatment_value - unit_effect * self.control_value, [ci_low, ci_high]
308308

309309
def estimate_ate(self) -> tuple[float, list[float, float], float]:
310310
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
@@ -315,8 +315,8 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
315315
model = self._run_linear_regression()
316316
# Create an empty individual for the control and treated
317317
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)
318-
individuals.loc["control", list(self.treatment)] = self.control_values
319-
individuals.loc["treated", list(self.treatment)] = self.treatment_values
318+
individuals.loc["control", list(self.treatment)] = self.control_value
319+
individuals.loc["treated", list(self.treatment)] = self.treatment_value
320320
# This is a temporary hack
321321
for t in self.square_terms:
322322
individuals[t + "^2"] = individuals[t] ** 2
@@ -338,7 +338,7 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
338338
self.model = model
339339

340340
x = pd.DataFrame()
341-
x[self.treatment[0]] = [self.treatment_values, self.control_values]
341+
x[self.treatment[0]] = [self.treatment_value, self.control_value]
342342
x["Intercept"] = self.intercept
343343
for k, v in self.effect_modifiers.items():
344344
x[k] = v
@@ -389,7 +389,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
389389
self.effect_modifiers
390390
), f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
391391
x = pd.DataFrame()
392-
x[self.treatment[0]] = [self.treatment_values, self.control_values]
392+
x[self.treatment[0]] = [self.treatment_value, self.control_value]
393393
x["Intercept"] = self.intercept
394394
for k, v in self.effect_modifiers.items():
395395
self.adjustment_set.add(k)
@@ -485,8 +485,8 @@ def estimate_ate(self) -> float:
485485
model.fit(outcome_df, treatment_df, X=effect_modifier_df, W=confounders_df)
486486

487487
# Obtain the ATE and 95% confidence intervals
488-
ate = model.ate(effect_modifier_df, T0=self.control_values, T1=self.treatment_values)
489-
ate_interval = model.ate_interval(effect_modifier_df, T0=self.control_values, T1=self.treatment_values)
488+
ate = model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
489+
ate_interval = model.ate_interval(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
490490
ci_low, ci_high = ate_interval[0], ate_interval[1]
491491
return ate, [ci_low, ci_high]
492492

@@ -525,9 +525,9 @@ def estimate_cates(self) -> pd.DataFrame:
525525
model.fit(outcome_df, treatment_df, X=effect_modifier_df, W=confounders_df)
526526

527527
# Obtain CATES and confidence intervals
528-
conditional_ates = model.effect(effect_modifier_df, T0=self.control_values, T1=self.treatment_values).flatten()
528+
conditional_ates = model.effect(effect_modifier_df, T0=self.control_value, T1=self.treatment_value).flatten()
529529
[ci_low, ci_high] = model.effect_interval(
530-
effect_modifier_df, T0=self.control_values, T1=self.treatment_values, alpha=0.05
530+
effect_modifier_df, T0=self.control_value, T1=self.treatment_value, alpha=0.05
531531
)
532532

533533
# Merge results into a dataframe (CATE, confidence intervals, and effect modifier values)

examples/poisson-line-process/causal_test_poisson.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from causal_testing.testing.causal_test_outcome import ExactValue, Positive
88
from causal_testing.testing.causal_test_engine import CausalTestEngine
99
from causal_testing.testing.estimators import LinearRegressionEstimator, Estimator
10+
from causal_testing.testing.base_test_case import BaseTestCase
1011

1112
import pandas as pd
1213

@@ -26,10 +27,10 @@ def estimate_ate(self) -> float:
2627
:return: The empirical average treatment effect.
2728
"""
2829
control_results = self.df.where(
29-
self.df[self.treatment[0]] == self.control_values
30+
self.df[self.treatment[0]] == self.control_value
3031
)[self.outcome].dropna()
3132
treatment_results = self.df.where(
32-
self.df[self.treatment[0]] == self.treatment_values
33+
self.df[self.treatment[0]] == self.treatment_value
3334
)[self.outcome].dropna()
3435
return treatment_results.mean()[0] - control_results.mean()[0], None
3536

@@ -38,10 +39,10 @@ def estimate_risk_ratio(self) -> float:
3839
:return: The empirical average treatment effect.
3940
"""
4041
control_results = self.df.where(
41-
self.df[self.treatment[0]] == self.control_values
42+
self.df[self.treatment[0]] == self.control_value
4243
)[self.outcome].dropna()
4344
treatment_results = self.df.where(
44-
self.df[self.treatment[0]] == self.treatment_values
45+
self.df[self.treatment[0]] == self.treatment_value
4546
)[self.outcome].dropna()
4647
return treatment_results.mean()[0] / control_results.mean()[0], None
4748

@@ -75,6 +76,10 @@ def estimate_risk_ratio(self) -> float:
7576
# 4. Construct a causal specification from the scenario and causal DAG
7677
causal_specification = CausalSpecification(scenario, causal_dag)
7778

79+
observational_data_path = "data/random/data_random_1000.csv"
80+
81+
intensity_num_shapes_results = []
82+
7883

7984
def test_intensity_num_shapes(
8085
observational_data_path,
@@ -92,24 +97,20 @@ def test_intensity_num_shapes(
9297
)
9398

9499
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
95-
causal_test_engine.identification(causal_test_case)
100+
minimal_adjustment_set = causal_dag.identification(causal_test_case.base_test_case)
96101

97102
# 9. Set up an estimator
98103
data = pd.read_csv(observational_data_path)
99104

100-
treatment = list(causal_test_case.control_input_configuration)[0].name
101-
outcome = list(causal_test_case.outcome_variables)[0].name
105+
treatment = causal_test_case.get_treatment_variable()
106+
outcome = causal_test_case.get_outcome_variable()
102107

103108
estimator = None
104109
if empirical:
105110
estimator = EmpiricalMeanEstimator(
106111
treatment=[treatment],
107-
control_values=list(causal_test_case.control_input_configuration.values())[
108-
0
109-
],
110-
treatment_values=list(
111-
causal_test_case.treatment_input_configuration.values()
112-
)[0],
112+
control_value=causal_test_case.control_value,
113+
treatment_value=causal_test_case.treatment_value,
113114
adjustment_set=set(),
114115
outcome=[outcome],
115116
df=data,
@@ -118,12 +119,8 @@ def test_intensity_num_shapes(
118119
else:
119120
estimator = LinearRegressionEstimator(
120121
treatment=[treatment],
121-
control_values=list(causal_test_case.control_input_configuration.values())[
122-
0
123-
],
124-
treatment_values=list(
125-
causal_test_case.treatment_input_configuration.values()
126-
)[0],
122+
control_value=causal_test_case.control_value,
123+
treatment_value=causal_test_case.treatment_value,
127124
adjustment_set=set(),
128125
outcome=[outcome],
129126
df=data,
@@ -143,9 +140,6 @@ def test_intensity_num_shapes(
143140
return causal_test_result
144141

145142

146-
observational_data_path = "data/random/data_random_1000.csv"
147-
148-
intensity_num_shapes_results = []
149143

150144
for wh in range(1, 11):
151145
smt_data_path = f"data/smt_100/data_smt_wh{wh}_100.csv"
@@ -154,14 +148,14 @@ def test_intensity_num_shapes(
154148
print(f"WIDTH = HEIGHT = {wh}")
155149

156150
print("Identifying")
157-
# 5. Create a causal test case
151+
base_test_case = BaseTestCase(treatment_variable=intensity,
152+
outcome_variable=num_shapes_unit)
158153
causal_test_case = CausalTestCase(
159-
control_input_configuration={intensity: control_value},
160-
treatment_input_configuration={intensity: treatment_value},
154+
base_test_case=base_test_case,
161155
expected_causal_effect=ExactValue(4, tolerance=0.5),
162-
outcome_variables={num_shapes_unit},
163-
estimate_type="risk_ratio",
164-
# effect_modifier_configuration={width: wh, height: wh}
156+
treatment_value=treatment_value,
157+
control_value=control_value,
158+
estimate_type="risk_ratio"
165159
)
166160
obs_causal_test_result = test_intensity_num_shapes(
167161
observational_data_path,
@@ -199,13 +193,17 @@ def test_intensity_num_shapes(
199193
# 5. Create a causal test case
200194
control_value = w
201195
treatment_value = w + 1
196+
base_test_case = BaseTestCase(
197+
treatment_variable=width,
198+
outcome_variable=num_shapes_unit
199+
)
202200
causal_test_case = CausalTestCase(
203-
control_input_configuration={width: control_value},
204-
treatment_input_configuration={width: treatment_value},
201+
base_test_case=base_test_case,
205202
expected_causal_effect=Positive(),
206-
outcome_variables={num_shapes_unit},
203+
control_value=control_value,
204+
treatment_value=treatment_value,
207205
estimate_type="ate_calculated",
208-
effect_modifier_configuration={intensity: i},
206+
effect_modifier_configuration={intensity: i}
209207
)
210208
causal_test_result = test_intensity_num_shapes(
211209
observational_data_path,

0 commit comments

Comments
 (0)