|
22 | 22 | from causal_testing.specification.variable import Input, Meta, Output
|
23 | 23 | from causal_testing.testing.causal_test_case import CausalTestCase
|
24 | 24 | from causal_testing.testing.causal_test_result import CausalTestResult
|
25 |
| -from causal_testing.testing.estimators import Estimator |
| 25 | +from causal_testing.testing.estimators import Estimator, LinearRegressionEstimator, LogisticRegressionEstimator |
26 | 26 | from causal_testing.testing.base_test_case import BaseTestCase
|
27 | 27 | from causal_testing.testing.causal_test_adequacy import DataAdequacy
|
28 | 28 |
|
@@ -307,20 +307,29 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
|
307 | 307 | :returns:
|
308 | 308 | - estimation_model - Estimator instance for the test being run
|
309 | 309 | """
|
310 |
| - minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case) |
311 |
| - treatment_var = causal_test_case.treatment_variable |
312 |
| - minimal_adjustment_set = minimal_adjustment_set - {treatment_var} |
313 |
| - estimator_kwargs = { |
314 |
| - "treatment": treatment_var.name, |
315 |
| - "treatment_value": causal_test_case.treatment_value, |
316 |
| - "control_value": causal_test_case.control_value, |
317 |
| - "adjustment_set": minimal_adjustment_set, |
318 |
| - "outcome": causal_test_case.outcome_variable.name, |
319 |
| - "effect_modifiers": causal_test_case.effect_modifier_configuration, |
320 |
| - "alpha": test["alpha"] if "alpha" in test else 0.05, |
321 |
| - } |
| 310 | + estimator_kwargs = {} |
322 | 311 | if "formula" in test:
|
| 312 | + if test["estimator"] != (LinearRegressionEstimator or LogisticRegressionEstimator): |
| 313 | + raise TypeError( |
| 314 | + "Currently only LinearRegressionEstimator and LogisticRegressionEstimator supports the use of " |
| 315 | + "formulas" |
| 316 | + ) |
323 | 317 | estimator_kwargs["formula"] = test["formula"]
|
| 318 | + estimator_kwargs["adjustment_set"] = None |
| 319 | + else: |
| 320 | + minimal_adjustment_set = self.causal_specification.causal_dag.identification( |
| 321 | + causal_test_case.base_test_case |
| 322 | + ) |
| 323 | + minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable} |
| 324 | + estimator_kwargs["adjustment_set"] = minimal_adjustment_set |
| 325 | + |
| 326 | + estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name |
| 327 | + estimator_kwargs["treatment_value"] = causal_test_case.treatment_value |
| 328 | + estimator_kwargs["control_value"] = causal_test_case.control_value |
| 329 | + estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name |
| 330 | + estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration |
| 331 | + estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05 |
| 332 | + |
324 | 333 | estimation_model = test["estimator"](**estimator_kwargs)
|
325 | 334 | return estimation_model
|
326 | 335 |
|
|
0 commit comments