Skip to content

Commit 71b5c5f

Browse files
committed
fix: tests for global and test-specific queries
1 parent e6f27fe commit 71b5c5f

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

tests/main_tests/test_main.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,55 @@ def test_ctf(self):
144144

145145
self.assertEqual(tests_passed, [True])
146146

147+
def test_global_query(self):
148+
framework = CausalTestingFramework(self.paths)
149+
framework.setup()
150+
151+
query_framework = CausalTestingFramework(self.paths, query="test_input > 0")
152+
query_framework.setup()
153+
154+
self.assertTrue(len(query_framework.data) > 0)
155+
self.assertTrue((query_framework.data["test_input"] > 0).all())
156+
157+
query_framework.create_variables()
158+
query_framework.create_scenario_and_specification()
159+
160+
self.assertIsNotNone(query_framework.scenario)
161+
self.assertIsNotNone(query_framework.causal_specification)
162+
163+
def test_test_specific_query(self):
164+
framework = CausalTestingFramework(self.paths)
165+
framework.setup()
166+
167+
with open(self.test_config_path, "r", encoding="utf-8") as f:
168+
test_configs = json.load(f)
169+
170+
test_config = test_configs["tests"][0].copy()
171+
test_config["query"] = "test_input > 0"
172+
173+
base_test = framework.create_base_test(test_config)
174+
causal_test = framework.create_causal_test(test_config, base_test)
175+
176+
self.assertTrue(len(causal_test.estimator.df) > 0)
177+
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
178+
179+
def test_combined_queries(self):
180+
global_framework = CausalTestingFramework(self.paths, query="test_input > 0")
181+
global_framework.setup()
182+
183+
with open(self.test_config_path, "r", encoding="utf-8") as f:
184+
test_configs = json.load(f)
185+
186+
test_config = test_configs["tests"][0].copy()
187+
test_config["query"] = "test_output > 0"
188+
189+
base_test = global_framework.create_base_test(test_config)
190+
causal_test = global_framework.create_causal_test(test_config, base_test)
191+
192+
self.assertTrue(len(causal_test.estimator.df) > 0)
193+
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
194+
self.assertTrue((causal_test.estimator.df["test_output"] > 0).all())
195+
147196
def test_parse_args(self):
148197
with unittest.mock.patch(
149198
"sys.argv",

tests/resources/data/tests.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"estimator": "LinearRegressionEstimator",
66
"estimate_type": "coefficient",
77
"effect_modifiers": [],
8+
"query": "test_input > 0",
89
"expected_effect": {"test_output": "NoEffect"},
910
"skip": false
1011
},
@@ -14,6 +15,7 @@
1415
"estimator": "LinearRegressionEstimator",
1516
"estimate_type": "coefficient",
1617
"effect_modifiers": [],
18+
"query": "test_input <= 5",
1719
"expected_effect": {"test_output": "NoEffect"},
1820
"skip": true
1921
}]

0 commit comments

Comments
 (0)