Skip to content

Commit 8419eb2

Browse files
authored
Merge pull request #316 from CITCOM-project/f-allian/front-end
Functionality for test-specific queries
2 parents 3bb31e9 + 8fa945b commit 8419eb2

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

causal_testing/main.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,17 +281,36 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
281281
if estimator_class is None:
282282
raise ValueError(f"Unknown estimator: {test['estimator']}")
283283

284+
# Handle combined queries (global and test-specific)
285+
test_query = test.get("query")
286+
combined_query = None
287+
288+
if self.query and test_query:
289+
combined_query = f"({self.query}) and ({test_query})"
290+
logger.info(
291+
f"Combining global query '{self.query}' with test-specific query "
292+
f"'{test_query}' for test '{test['name']}'"
293+
)
294+
elif test_query:
295+
combined_query = test_query
296+
logger.info(f"Using test-specific query for '{test['name']}': {test_query}")
297+
elif self.query:
298+
combined_query = self.query
299+
logger.info(f"Using global query for '{test['name']}': {self.query}")
300+
301+
filtered_df = self.data.query(combined_query) if combined_query else self.data
302+
284303
# Create the estimator with correct parameters
285304
estimator = estimator_class(
286305
base_test_case=base_test,
287306
treatment_value=test.get("treatment_value"),
288307
control_value=test.get("control_value"),
289308
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
290-
df=self.data,
309+
df=filtered_df,
291310
effect_modifiers=None,
292311
formula=test.get("formula"),
293312
alpha=test.get("alpha", 0.05),
294-
query="",
313+
query=combined_query,
295314
)
296315

297316
# Get effect type and create expected effect
@@ -424,7 +443,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
424443
"effect": test_config.get("effect", "direct"),
425444
"treatment_variable": test_config["treatment_variable"],
426445
"expected_effect": test_config["expected_effect"],
427-
"formula": test_config.get("formula"),
446+
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
428447
"alpha": test_config.get("alpha", 0.05),
429448
"skip": test_config.get("skip", False),
430449
"passed": test_passed,

tests/main_tests/test_main.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,67 @@ def test_ctf(self):
144144

145145
self.assertEqual(tests_passed, [True])
146146

147+
def test_global_query(self):
148+
framework = CausalTestingFramework(self.paths)
149+
framework.setup()
150+
151+
query_framework = CausalTestingFramework(self.paths, query="test_input > 0")
152+
query_framework.setup()
153+
154+
self.assertTrue(len(query_framework.data) > 0)
155+
self.assertTrue((query_framework.data["test_input"] > 0).all())
156+
157+
with open(self.test_config_path, "r", encoding="utf-8") as f:
158+
test_configs = json.load(f)
159+
160+
test_config = test_configs["tests"][0].copy()
161+
if "query" in test_config:
162+
del test_config["query"]
163+
164+
base_test = query_framework.create_base_test(test_config)
165+
causal_test = query_framework.create_causal_test(test_config, base_test)
166+
167+
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
168+
169+
query_framework.create_variables()
170+
query_framework.create_scenario_and_specification()
171+
172+
self.assertIsNotNone(query_framework.scenario)
173+
self.assertIsNotNone(query_framework.causal_specification)
174+
175+
def test_test_specific_query(self):
176+
framework = CausalTestingFramework(self.paths)
177+
framework.setup()
178+
179+
with open(self.test_config_path, "r", encoding="utf-8") as f:
180+
test_configs = json.load(f)
181+
182+
test_config = test_configs["tests"][0].copy()
183+
test_config["query"] = "test_input > 0"
184+
185+
base_test = framework.create_base_test(test_config)
186+
causal_test = framework.create_causal_test(test_config, base_test)
187+
188+
self.assertTrue(len(causal_test.estimator.df) > 0)
189+
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
190+
191+
def test_combined_queries(self):
192+
global_framework = CausalTestingFramework(self.paths, query="test_input > 0")
193+
global_framework.setup()
194+
195+
with open(self.test_config_path, "r", encoding="utf-8") as f:
196+
test_configs = json.load(f)
197+
198+
test_config = test_configs["tests"][0].copy()
199+
test_config["query"] = "test_output > 0"
200+
201+
base_test = global_framework.create_base_test(test_config)
202+
causal_test = global_framework.create_causal_test(test_config, base_test)
203+
204+
self.assertTrue(len(causal_test.estimator.df) > 0)
205+
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
206+
self.assertTrue((causal_test.estimator.df["test_output"] > 0).all())
207+
147208
def test_parse_args(self):
148209
with unittest.mock.patch(
149210
"sys.argv",

tests/resources/data/tests.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"estimator": "LinearRegressionEstimator",
66
"estimate_type": "coefficient",
77
"effect_modifiers": [],
8+
"query": "test_input > 0",
89
"expected_effect": {"test_output": "NoEffect"},
910
"skip": false
1011
},
@@ -14,6 +15,7 @@
1415
"estimator": "LinearRegressionEstimator",
1516
"estimate_type": "coefficient",
1617
"effect_modifiers": [],
18+
"query": "test_input <= 5",
1719
"expected_effect": {"test_output": "NoEffect"},
1820
"skip": true
1921
}]

0 commit comments

Comments
 (0)