Skip to content

Commit e7b8577

Browse files
authored
Merge pull request #328 from CITCOM-project/f-allian/chore
Enforce NetworkX 3.4
2 parents 7da3aef + a940927 commit e7b8577

File tree

4 files changed

+48
-1
lines changed

4 files changed

+48
-1
lines changed

.github/workflows/ci-tests-drafts.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
python-version: ${{ matrix.python-version }}
2323
- name: Install dependencies
2424
run: |
25+
python --version
2526
python -m pip install --upgrade pip
2627
pip install -e .
2728
pip install -e .[test]

causal_testing/testing/causal_test_case.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def __init__(
4747
self.estimator = estimator
4848
if estimate_params is None:
4949
self.estimate_params = {}
50+
51+
else:
52+
self.estimate_params = estimate_params
53+
5054
self.effect = base_test_case.effect
5155

5256
if effect_modifier_configuration:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"fitter~=1.7",
1919
"lifelines~=0.29.0",
2020
"lhsmdu~=1.1",
21-
"networkx~=3.4",
21+
"networkx>=3.4,<3.5",
2222
"numpy~=1.26",
2323
"pandas>=2.1",
2424
"scikit_learn~=1.4",

tests/testing_tests/test_causal_test_case.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,45 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
203203
)
204204
causal_test_result = self.causal_test_case.execute_test(estimation_model)
205205
pd.testing.assert_series_equal(causal_test_result.test_value.value, pd.Series(4.0), atol=1)
206+
207+
def test_estimate_params_none(self):
208+
"""Check that estimate_params defaults to empty dict when None is passed into the estimator object"""
209+
estimator = LinearRegressionEstimator(
210+
base_test_case=self.base_test_case_A_C,
211+
adjustment_set=set(),
212+
control_value=0,
213+
treatment_value=1,
214+
formula="C ~ A + D",
215+
df=self.df,
216+
)
217+
causal_test_case = CausalTestCase(
218+
base_test_case=self.base_test_case_A_C,
219+
expected_causal_effect=self.expected_causal_effect,
220+
estimate_params=None,
221+
estimator=estimator,
222+
estimate_type="risk_ratio",
223+
)
224+
self.assertEqual(causal_test_case.estimate_params, {})
225+
with self.assertRaises(ValueError):
226+
causal_test_case.execute_test()
227+
228+
def test_estimate_params_with_formula(self):
229+
"""Ensure estimate params is handled correctly when a formula is passed into the estimator object"""
230+
estimate_params = {"adjustment_config": {"D": 1}}
231+
estimator = LinearRegressionEstimator(
232+
base_test_case=self.base_test_case_A_C,
233+
adjustment_set=set(),
234+
control_value=0,
235+
treatment_value=1,
236+
formula="C ~ A + D",
237+
df=self.df,
238+
)
239+
causal_test_case = CausalTestCase(
240+
base_test_case=self.base_test_case_A_C,
241+
expected_causal_effect=self.expected_causal_effect,
242+
estimate_params=estimate_params,
243+
estimate_type="risk_ratio",
244+
estimator=estimator,
245+
)
246+
self.assertEqual(causal_test_case.estimate_params, estimate_params)
247+
self.assertEqual(round(causal_test_case.execute_test().test_value.value[0], 3), 1.444)

0 commit comments

Comments
 (0)