Skip to content

Commit 2ca274d

Browse files
committed
black
1 parent c8aa97e commit 2ca274d

File tree

3 files changed

+57
-36
lines changed

3 files changed

+57
-36
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,10 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
7171
minimal_adjustment_set = minimal_adjustment_set - set(edge.treatment_variable.name)
7272
minimal_adjustment_set = minimal_adjustment_set - set(edge.outcome_variable.name)
7373

74-
variables_for_positivity = (
75-
list(minimal_adjustment_set) + [edge.treatment_variable.name, edge.outcome_variable.name]
76-
)
74+
variables_for_positivity = list(minimal_adjustment_set) + [
75+
edge.treatment_variable.name,
76+
edge.outcome_variable.name,
77+
]
7778

7879
if self._check_positivity_violation(variables_for_positivity):
7980
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
@@ -210,7 +211,9 @@ def _check_positivity_violation(self, variables_list):
210211
:param variables_list: The list of variables for which positivity must be satisfied.
211212
:return: True if positivity is violated, False otherwise.
212213
"""
213-
if not (set(variables_list) - {x.name for x in self.scenario.hidden_variables()}).issubset(self.scenario_execution_data_df.columns):
214+
if not (set(variables_list) - {x.name for x in self.scenario.hidden_variables()}).issubset(
215+
self.scenario_execution_data_df.columns
216+
):
214217
missing_variables = set(variables_list) - set(self.scenario_execution_data_df.columns)
215218
logger.warning(
216219
"Positivity violation: missing data for variables %s.\n"

tests/json_front_tests/test_json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def setUp(self) -> None:
4040
def test_setting_paths(self):
4141
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
4242
self.assertEqual(self.json_class.paths.dag_path, Path(self.dag_path))
43-
self.assertEqual(self.json_class.paths.data_paths, [Path(self.data_path[0])]) # Needs to be list of Paths
43+
self.assertEqual(self.json_class.paths.data_paths, [Path(self.data_path[0])]) # Needs to be list of Paths
4444

4545
def test_set_inputs(self):
4646
ctf_input = [Input("test_input", float, self.example_distribution)]

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ def test_None_ci(self):
2727

2828
self.assertIsNone(ctr.ci_low())
2929
self.assertIsNone(ctr.ci_high())
30-
self.assertEqual(ctr.to_dict(),
31-
{"treatment": "A",
32-
"control_value": 0,
33-
"treatment_value": 1,
34-
"outcome": "A",
35-
"adjustment_set": set(),
36-
"test_value": test_value})
30+
self.assertEqual(
31+
ctr.to_dict(),
32+
{
33+
"treatment": "A",
34+
"control_value": 0,
35+
"treatment_value": 1,
36+
"outcome": "A",
37+
"adjustment_set": set(),
38+
"test_value": test_value,
39+
},
40+
)
3741

3842
def test_empty_adjustment_set(self):
3943
test_value = TestValue(type="ate", value=0)
@@ -46,13 +50,18 @@ def test_empty_adjustment_set(self):
4650

4751
self.assertIsNone(ctr.ci_low())
4852
self.assertIsNone(ctr.ci_high())
49-
self.assertEqual(str(ctr), ("Causal Test Result\n==============\n"
50-
"Treatment: A\n"
51-
"Control value: 0\n"
52-
"Treatment value: 1\n"
53-
"Outcome: A\n"
54-
"Adjustment set: set()\n"
55-
"ate: 0\n" ))
53+
self.assertEqual(
54+
str(ctr),
55+
(
56+
"Causal Test Result\n==============\n"
57+
"Treatment: A\n"
58+
"Control value: 0\n"
59+
"Treatment value: 1\n"
60+
"Outcome: A\n"
61+
"Adjustment set: set()\n"
62+
"ate: 0\n"
63+
),
64+
)
5665

5766
def test_exactValue_pass(self):
5867
test_value = TestValue(type="ate", value=5.05)
@@ -97,20 +106,29 @@ def test_someEffect_fail(self):
97106
)
98107
ev = SomeEffect()
99108
self.assertFalse(ev.apply(ctr))
100-
self.assertEqual(str(ctr), ("Causal Test Result\n==============\n"
101-
"Treatment: A\n"
102-
"Control value: 0\n"
103-
"Treatment value: 1\n"
104-
"Outcome: A\n"
105-
"Adjustment set: set()\n"
106-
"ate: 0\n"
107-
"Confidence intervals: [-0.1, 0.2]\n" ))
108-
self.assertEqual(ctr.to_dict(),
109-
{"treatment": "A",
110-
"control_value": 0,
111-
"treatment_value": 1,
112-
"outcome": "A",
113-
"adjustment_set": set(),
114-
"test_value": test_value,
115-
"ci_low": -0.1,
116-
"ci_high": 0.2})
109+
self.assertEqual(
110+
str(ctr),
111+
(
112+
"Causal Test Result\n==============\n"
113+
"Treatment: A\n"
114+
"Control value: 0\n"
115+
"Treatment value: 1\n"
116+
"Outcome: A\n"
117+
"Adjustment set: set()\n"
118+
"ate: 0\n"
119+
"Confidence intervals: [-0.1, 0.2]\n"
120+
),
121+
)
122+
self.assertEqual(
123+
ctr.to_dict(),
124+
{
125+
"treatment": "A",
126+
"control_value": 0,
127+
"treatment_value": 1,
128+
"outcome": "A",
129+
"adjustment_set": set(),
130+
"test_value": test_value,
131+
"ci_low": -0.1,
132+
"ci_high": 0.2,
133+
},
134+
)

0 commit comments

Comments
 (0)