diff --git a/causal_testing/main.py b/causal_testing/main.py
index c2a78977..9685e723 100644
--- a/causal_testing/main.py
+++ b/causal_testing/main.py
@@ -281,17 +281,39 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
         if estimator_class is None:
             raise ValueError(f"Unknown estimator: {test['estimator']}")
 
+        # Build the pandas query for this test: a test-specific "query" entry is
+        # AND-ed with the framework-wide query when both are present; otherwise
+        # whichever one is set is used on its own.
+        test_query = test.get("query")
+        combined_query = None
+
+        if self.query and test_query:
+            combined_query = f"({self.query}) and ({test_query})"
+            logger.info(
+                f"Combining global query '{self.query}' with test-specific query "
+                f"'{test_query}' for test '{test['name']}'"
+            )
+        elif test_query:
+            combined_query = test_query
+            logger.info(f"Using test-specific query for '{test['name']}': {test_query}")
+        elif self.query:
+            combined_query = self.query
+            logger.info(f"Using global query for '{test['name']}': {self.query}")
+
+        filtered_df = self.data.query(combined_query) if combined_query else self.data
+
         # Create the estimator with correct parameters
         estimator = estimator_class(
             base_test_case=base_test,
             treatment_value=test.get("treatment_value"),
             control_value=test.get("control_value"),
             adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
-            df=self.data,
+            df=filtered_df,
             effect_modifiers=None,
             formula=test.get("formula"),
             alpha=test.get("alpha", 0.05),
-            query="",
+            # Keep passing a str (as the previous `query=""` did) when no filter applies.
+            query=combined_query or "",
         )
 
         # Get effect type and create expected effect
@@ -424,7 +446,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
                     "effect": test_config.get("effect", "direct"),
                     "treatment_variable": test_config["treatment_variable"],
                     "expected_effect": test_config["expected_effect"],
-                    "formula": test_config.get("formula"),
+                    "formula": getattr(result.estimator, "formula", None),
                     "alpha": test_config.get("alpha", 0.05),
                     "skip": test_config.get("skip", False),
                     "passed": test_passed,
diff --git a/tests/main_tests/test_main.py b/tests/main_tests/test_main.py
index 1a0538f1..81e46742 100644
--- a/tests/main_tests/test_main.py
+++ b/tests/main_tests/test_main.py
@@ -144,6 +144,67 @@ def test_ctf(self):
 
         self.assertEqual(tests_passed, [True])
 
+    def test_global_query(self):
+        # A framework-wide query must filter both the loaded data and the estimator's data.
+        query_framework = CausalTestingFramework(self.paths, query="test_input > 0")
+        query_framework.setup()
+
+        self.assertGreater(len(query_framework.data), 0)
+        self.assertTrue((query_framework.data["test_input"] > 0).all())
+
+        with open(self.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)
+
+        # Drop any test-specific query so only the global one is in effect.
+        test_config = test_configs["tests"][0].copy()
+        test_config.pop("query", None)
+
+        base_test = query_framework.create_base_test(test_config)
+        causal_test = query_framework.create_causal_test(test_config, base_test)
+
+        self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
+
+        query_framework.create_variables()
+        query_framework.create_scenario_and_specification()
+
+        self.assertIsNotNone(query_framework.scenario)
+        self.assertIsNotNone(query_framework.causal_specification)
+
+    def test_test_specific_query(self):
+        # A per-test "query" entry must filter the estimator's data without a global query.
+        framework = CausalTestingFramework(self.paths)
+        framework.setup()
+
+        with open(self.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)
+
+        test_config = test_configs["tests"][0].copy()
+        test_config["query"] = "test_input > 0"
+
+        base_test = framework.create_base_test(test_config)
+        causal_test = framework.create_causal_test(test_config, base_test)
+
+        self.assertGreater(len(causal_test.estimator.df), 0)
+        self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
+
+    def test_combined_queries(self):
+        # Global and per-test queries must both be applied to the estimator's data.
+        global_framework = CausalTestingFramework(self.paths, query="test_input > 0")
+        global_framework.setup()
+
+        with open(self.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)
+
+        test_config = test_configs["tests"][0].copy()
+        test_config["query"] = "test_output > 0"
+
+        base_test = global_framework.create_base_test(test_config)
+        causal_test = global_framework.create_causal_test(test_config, base_test)
+
+        self.assertGreater(len(causal_test.estimator.df), 0)
+        self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
+        self.assertTrue((causal_test.estimator.df["test_output"] > 0).all())
+
     def test_parse_args(self):
         with unittest.mock.patch(
             "sys.argv",
diff --git a/tests/resources/data/tests.json b/tests/resources/data/tests.json
index 43d875a6..1a92cb2d 100644
--- a/tests/resources/data/tests.json
+++ b/tests/resources/data/tests.json
@@ -5,6 +5,7 @@
         "estimator": "LinearRegressionEstimator",
         "estimate_type": "coefficient",
         "effect_modifiers": [],
+        "query": "test_input > 0",
         "expected_effect": {"test_output": "NoEffect"},
         "skip": false
     },
@@ -14,6 +15,7 @@
         "estimator": "LinearRegressionEstimator",
         "estimate_type": "coefficient",
         "effect_modifiers": [],
+        "query": "test_input <= 5",
         "expected_effect": {"test_output": "NoEffect"},
         "skip": true
     }]