diff --git a/.github/workflows/ci-tests-drafts.yaml b/.github/workflows/ci-tests-drafts.yaml index 6203892b..5f82a35a 100644 --- a/.github/workflows/ci-tests-drafts.yaml +++ b/.github/workflows/ci-tests-drafts.yaml @@ -22,6 +22,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + python --version python -m pip install --upgrade pip pip install -e . pip install -e .[test] diff --git a/causal_testing/testing/causal_test_case.py b/causal_testing/testing/causal_test_case.py index 64fb75cb..c4969ec4 100644 --- a/causal_testing/testing/causal_test_case.py +++ b/causal_testing/testing/causal_test_case.py @@ -47,6 +47,10 @@ def __init__( self.estimator = estimator if estimate_params is None: self.estimate_params = {} + + else: + self.estimate_params = estimate_params + self.effect = base_test_case.effect if effect_modifier_configuration: diff --git a/pyproject.toml b/pyproject.toml index 0eaee0ee..7cdc1cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "fitter~=1.7", "lifelines~=0.29.0", "lhsmdu~=1.1", - "networkx~=3.4", + "networkx>=3.4,<3.5", "numpy~=1.26", "pandas>=2.1", "scikit_learn~=1.4", diff --git a/tests/testing_tests/test_causal_test_case.py b/tests/testing_tests/test_causal_test_case.py index 4c4a69c6..154d76e3 100644 --- a/tests/testing_tests/test_causal_test_case.py +++ b/tests/testing_tests/test_causal_test_case.py @@ -203,3 +203,45 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel ) causal_test_result = self.causal_test_case.execute_test(estimation_model) pd.testing.assert_series_equal(causal_test_result.test_value.value, pd.Series(4.0), atol=1) + + def test_estimate_params_none(self): + """Check that estimate_params defaults to empty dict when None is passed into the estimator object""" + estimator = LinearRegressionEstimator( + base_test_case=self.base_test_case_A_C, + adjustment_set=set(), + control_value=0, + treatment_value=1, + formula="C ~ A + D", + df=self.df, + ) + causal_test_case = CausalTestCase( + base_test_case=self.base_test_case_A_C, + expected_causal_effect=self.expected_causal_effect, + estimate_params=None, + estimator=estimator, + estimate_type="risk_ratio", + ) + self.assertEqual(causal_test_case.estimate_params, {}) + with self.assertRaises(ValueError): + causal_test_case.execute_test() + + def test_estimate_params_with_formula(self): + """Ensure estimate params is handled correctly when a formula is passed into the estimator object""" + estimate_params = {"adjustment_config": {"D": 1}} + estimator = LinearRegressionEstimator( + base_test_case=self.base_test_case_A_C, + adjustment_set=set(), + control_value=0, + treatment_value=1, + formula="C ~ A + D", + df=self.df, + ) + causal_test_case = CausalTestCase( + base_test_case=self.base_test_case_A_C, + expected_causal_effect=self.expected_causal_effect, + estimate_params=estimate_params, + estimate_type="risk_ratio", + estimator=estimator, + ) + self.assertEqual(causal_test_case.estimate_params, estimate_params) + self.assertEqual(round(causal_test_case.execute_test().test_value.value[0], 3), 1.444)