Skip to content

Commit 9edf381

Browse files
Reduce the number of instance attributes
1 parent b227234 commit 9edf381

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

causal_testing/testing/causal_test_result.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""This module contains the CausalTestResult class, which is a container for the results of a causal test, and the
22
TestValue dataclass.
33
"""
4-
from typing import Any, Union
4+
from typing import Any
55
from dataclasses import dataclass
66

7+
from causal_testing.testing.estimators import Estimator
78
from causal_testing.specification.variable import Variable
89

910

@@ -22,21 +23,14 @@ class CausalTestResult:
2223

2324
def __init__(
2425
self,
25-
treatment: Variable,
26-
outcome: Variable,
27-
treatment_value: Union[int, float, str],
28-
control_value: Union[int, float, str],
29-
adjustment_set: set,
26+
estimator: Estimator,
3027
test_value: TestValue,
3128
confidence_intervals: [float, float] = None,
3229
effect_modifier_configuration: {Variable: Any} = None,
3330
):
34-
self.treatment = treatment
35-
self.outcome = outcome
36-
self.treatment_value = treatment_value
37-
self.control_value = control_value
38-
if adjustment_set:
39-
self.adjustment_set = adjustment_set
31+
self.estimator = estimator
32+
if estimator.adjustment_set:
33+
self.adjustment_set = estimator.adjustment_set
4034
else:
4135
self.adjustment_set = set()
4236
self.test_value = test_value
@@ -50,10 +44,10 @@ def __init__(
5044
def __str__(self):
5145
base_str = (
5246
f"Causal Test Result\n==============\n"
53-
f"Treatment: {self.treatment[0]}\n"
54-
f"Control value: {self.control_value}\n"
55-
f"Treatment value: {self.treatment_value}\n"
56-
f"Outcome: {self.outcome[0]}\n"
47+
f"Treatment: {self.estimator.treatment[0]}\n"
48+
f"Control value: {self.estimator.control_value}\n"
49+
f"Treatment value: {self.estimator.treatment_value}\n"
50+
f"Outcome: {self.estimator.outcome[0]}\n"
5751
f"Adjustment set: {self.adjustment_set}\n"
5852
f"{self.test_value.type}: {self.test_value.value}\n"
5953
)
@@ -67,10 +61,10 @@ def to_dict(self):
6761
:return: Dictionary containing contents of causal_test_result
6862
"""
6963
base_dict = {
70-
"treatment": self.treatment[0],
71-
"control_value": self.control_value,
72-
"treatment_value": self.treatment_value,
73-
"outcome": self.outcome[0],
64+
"treatment": self.estimator.treatment[0],
65+
"control_value": self.estimator.control_value,
66+
"treatment_value": self.estimator.treatment_value,
67+
"outcome": self.estimator.outcome[0],
7468
"adjustment_set": self.adjustment_set,
7569
"test_value": self.test_value,
7670
}
@@ -94,7 +88,7 @@ def ci_high(self):
9488
def summary(self):
9589
"""Summarise the causal test result as an intuitive sentence."""
9690
print(
97-
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
98-
f"{self.treatment[0]}' = {self.treatment_value} is {self.test_value.value} (95% confidence intervals: "
99-
f"{self.confidence_intervals})."
91+
f"The causal effect of changing {self.estimator.treatment[0]} = {self.estimator.control_value} to "
92+
f"{self.estimator.treatment[0]}' = {self.estimator.treatment_value} is {self.test_value.value}"
93+
f"(95% confidence intervals: {self.confidence_intervals})."
10094
)

0 commit comments

Comments
 (0)