Skip to content

Commit 6efe594

Browse files
committed
increased coverage a little
1 parent 673165c commit 6efe594

File tree

8 files changed

+24
-21
lines changed

8 files changed

+24
-21
lines changed

causal_testing/estimation/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@ def __init__(
5050

5151
if effect_modifiers is None:
5252
self.effect_modifiers = {}
53-
elif isinstance(effect_modifiers, dict):
54-
self.effect_modifiers = effect_modifiers
5553
else:
56-
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
54+
self.effect_modifiers = effect_modifiers
5755
self.modelling_assumptions = []
5856
if query:
5957
self.modelling_assumptions.append(query)

causal_testing/estimation/gp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@ def reciprocal(x: float) -> float:
3131

3232
def mut_insert(expression: gp.PrimitiveTree, pset: gp.PrimitiveSet):
3333
"""
34-
Copied from gp.mutInsert, except that we import isclass from inspect, so we
35-
won't have the "isclass not defined" bug.
34+
NOTE: This is a temporary workaround. This method is copied verbatim from
35+
gp.mutInsert. It seems they forgot to import isclass from inspect, so their
36+
method throws an error, saying that "isclass is not defined". A couple of
37+
lines are not covered by tests, but since this is 1. a temporary workaround
38+
until they release a new version of DEAP, and 2. not our code, I don't think
39+
that matters.
3640
3741
Inserts a new branch at a random position in *expression*. The subtree
3842
at the chosen position is used as child node of the created subtree, in
@@ -374,6 +378,4 @@ def mutate(self, expression: gp.PrimitiveTree) -> gp.PrimitiveTree:
374378
mutated = mut_insert(self.toolbox.clone(expression), self.pset)
375379
elif choice == 3:
376380
mutated = gp.mutShrink(self.toolbox.clone(expression))
377-
else:
378-
raise ValueError("Invalid mutation choice")
379381
return mutated

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series
164164
165165
:return: The average treatment effect and the 95% Wald confidence intervals.
166166
"""
167-
if adjustment_config is None:
168-
adjustment_config = {}
169167
prediction = self._predict(adjustment_config=adjustment_config)
170168
control_outcome, treatment_outcome = prediction.iloc[1], prediction.iloc[0]
171169
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"])
@@ -184,8 +182,6 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se
184182
185183
:return: The average treatment effect and the 95% Wald confidence intervals.
186184
"""
187-
if adjustment_config is None:
188-
adjustment_config = {}
189185
prediction = self._predict(adjustment_config=adjustment_config)
190186
control_outcome, treatment_outcome = prediction.iloc[1], prediction.iloc[0]
191187
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"])

tests/estimation_tests/test_gp.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import unittest
2+
import pandas as pd
3+
4+
from causal_testing.estimation.gp import GP
5+
6+
7+
class TestGP(unittest.TestCase):
8+
9+
def test_init_invalid_fun_name(self):
10+
with self.assertRaises(ValueError):
11+
GP(df=pd.DataFrame(), features=[], outcome="", max_order=2, sympy_conversions={"power_1": ""})

tests/specification_tests/test_capabilities.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44

55
class TestCapability(unittest.TestCase):
6-
76
"""
87
Test the Capability class for basic methods.
98
"""
@@ -17,7 +16,6 @@ def test_repr(self):
1716

1817

1918
class TestTreatmentSequence(unittest.TestCase):
20-
2119
"""
2220
Test the TreatmentSequence class for basic methods.
2321
"""

tests/specification_tests/test_causal_dag.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from causal_testing.testing.base_test_case import BaseTestCase
99

1010

11-
12-
1311
class TestCausalDAGIssue90(unittest.TestCase):
1412
"""
1513
Test the CausalDAG class for the resolution of Issue 90.
@@ -63,10 +61,11 @@ def test_common_cause(self):
6361
causal_dag.graph.add_edge("U", "I")
6462
with self.assertRaises(ValueError):
6563
causal_dag.check_iv_assumptions("X", "Y", "I")
66-
64+
6765
def tearDown(self) -> None:
6866
shutil.rmtree(self.temp_dir_path)
6967

68+
7069
class TestCausalDAG(unittest.TestCase):
7170
"""
7271
Test the CausalDAG class for creation of Causal Directed Acyclic Graphs (DAGs).
@@ -154,10 +153,11 @@ def test_direct_effect_adjustment_sets_no_adjustment(self):
154153
causal_dag = CausalDAG(self.dag_dot_path)
155154
adjustment_sets = causal_dag.direct_effect_adjustment_sets(["X2"], ["D1"])
156155
self.assertEqual(list(adjustment_sets), [set()])
157-
156+
158157
def tearDown(self) -> None:
159158
shutil.rmtree(self.temp_dir_path)
160159

160+
161161
class TestDAGIdentification(unittest.TestCase):
162162
"""
163163
Test the Causal DAG identification algorithms and supporting algorithms.
@@ -345,6 +345,7 @@ def test_dag_with_non_character_nodes(self):
345345
def tearDown(self) -> None:
346346
shutil.rmtree(self.temp_dir_path)
347347

348+
348349
class TestDependsOnOutputs(unittest.TestCase):
349350
"""
350351
Test the depends_on_outputs method.

tests/specification_tests/test_variable.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
class TestVariable(unittest.TestCase):
10-
1110
"""
1211
Test the Variable class for basic methods.
1312
"""
@@ -143,7 +142,6 @@ def test_copy(self):
143142

144143

145144
class TestZ3Methods(unittest.TestCase):
146-
147145
"""
148146
Test the Variable class for Z3 methods.
149147

tests/testing_tests/test_causal_test_case.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def test_check_minimum_adjustment_set(self):
107107
minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
108108
self.assertEqual(minimal_adjustment_set, {"D"})
109109

110-
111110
def test_invalid_causal_effect(self):
112111
"""Check that executing the causal test case returns the correct results for dummy data using a linear
113112
regression estimator."""
@@ -170,7 +169,7 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
170169
)
171170
self.causal_test_case.estimate_type = "coefficient"
172171
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
173-
pd.testing.assert_series_equal(causal_test_result.test_value.value, pd.Series({'D': 0.0}), atol=1e-1)
172+
pd.testing.assert_series_equal(causal_test_result.test_value.value, pd.Series({"D": 0.0}), atol=1e-1)
174173

175174
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
176175
"""Check that executing the causal test case returns the correct results for dummy data using a linear

0 commit comments

Comments
 (0)