Skip to content

Commit d8439a7

Browse files
Merge branch 'main' into estimator_params_for_linear
2 parents ef094af + ee8be4d commit d8439a7

File tree

5 files changed

+8
-34
lines changed

5 files changed

+8
-34
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,6 @@ def __init__(
5757
else:
5858
self.effect_modifier_configuration = {}
5959

60-
def get_treatment_variable(self):
61-
"""Return the treatment variable name (as string) for this causal test case"""
62-
return self.treatment_variable.name
63-
64-
def get_outcome_variable(self):
65-
"""Return the outcome variable name (as string) for this causal test case."""
66-
return self.outcome_variable.name
67-
68-
def get_control_value(self):
69-
"""Return a the control value of the treatment variable in this causal test case."""
70-
return self.control_value
71-
72-
def get_treatment_value(self):
73-
"""Return the treatment value of the treatment variable in this causal test case."""
74-
return self.treatment_value
75-
7660
def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
7761
"""Execute a causal test case and return the causal test result.
7862

causal_testing/testing/causal_test_result.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ def to_dict(self, json=False):
8585
"outcome": self.estimator.outcome,
8686
"adjustment_set": list(self.adjustment_set) if json else self.adjustment_set,
8787
"effect_measure": self.test_value.type,
88-
"effect_estimate": self.test_value.value,
89-
"ci_low": self.ci_low(),
90-
"ci_high": self.ci_high(),
88+
"effect_estimate": self.test_value.value.to_dict()
89+
if json and hasattr(self.test_value.value, "to_dict")
90+
else self.test_value.value,
91+
"ci_low": self.ci_low().to_dict() if json and hasattr(self.ci_low(), "to_dict") else self.ci_low(),
92+
"ci_high": self.ci_high().to_dict() if json and hasattr(self.ci_high(), "to_dict") else self.ci_high(),
9193
}
9294
if self.adequacy:
9395
base_dict["adequacy"] = self.adequacy.to_dict()

examples/poisson-line-process/example_poisson_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def causal_test_intensity_num_shapes(
9090
# 8. Set up an estimator
9191
data = pd.read_csv(observational_data_path)
9292

93-
treatment = causal_test_case.get_treatment_variable()
94-
outcome = causal_test_case.get_outcome_variable()
93+
treatment = causal_test_case.treatment_variable.name
94+
outcome = causal_test_case.outcome_variable.name
9595

9696
estimator = None
9797
if empirical:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies = [
2020
"fitter~=1.4",
2121
"lhsmdu~=1.1",
2222
"networkx~=2.6",
23-
"numpy~=1.22.0",
23+
"numpy~=1.23",
2424
"pandas~=1.3",
2525
"scikit_learn~=1.1",
2626
"scipy~=1.7",

tests/testing_tests/test_causal_test_case.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,6 @@ def setUp(self) -> None:
3737
treatment_value=1,
3838
)
3939

40-
def test_get_treatment_variable(self):
41-
self.assertEqual(self.causal_test_case.get_treatment_variable(), "A")
42-
43-
def test_get_outcome_variable(self):
44-
self.assertEqual(self.causal_test_case.get_outcome_variable(), "C")
45-
46-
def test_get_treatment_value(self):
47-
self.assertEqual(self.causal_test_case.get_treatment_value(), 1)
48-
49-
def test_get_control_value(self):
50-
self.assertEqual(self.causal_test_case.get_control_value(), 0)
51-
5240
def test_str(self):
5341
self.assertEqual(
5442
str(self.causal_test_case),

0 commit comments

Comments
 (0)