@@ -144,6 +144,55 @@ def test_ctf(self):
144
144
145
145
self .assertEqual (tests_passed , [True ])
146
146
147
+ def test_global_query (self ):
148
+ framework = CausalTestingFramework (self .paths )
149
+ framework .setup ()
150
+
151
+ query_framework = CausalTestingFramework (self .paths , query = "test_input > 0" )
152
+ query_framework .setup ()
153
+
154
+ self .assertTrue (len (query_framework .data ) > 0 )
155
+ self .assertTrue ((query_framework .data ["test_input" ] > 0 ).all ())
156
+
157
+ query_framework .create_variables ()
158
+ query_framework .create_scenario_and_specification ()
159
+
160
+ self .assertIsNotNone (query_framework .scenario )
161
+ self .assertIsNotNone (query_framework .causal_specification )
162
+
163
+ def test_test_specific_query (self ):
164
+ framework = CausalTestingFramework (self .paths )
165
+ framework .setup ()
166
+
167
+ with open (self .test_config_path , "r" , encoding = "utf-8" ) as f :
168
+ test_configs = json .load (f )
169
+
170
+ test_config = test_configs ["tests" ][0 ].copy ()
171
+ test_config ["query" ] = "test_input > 0"
172
+
173
+ base_test = framework .create_base_test (test_config )
174
+ causal_test = framework .create_causal_test (test_config , base_test )
175
+
176
+ self .assertTrue (len (causal_test .estimator .df ) > 0 )
177
+ self .assertTrue ((causal_test .estimator .df ["test_input" ] > 0 ).all ())
178
+
179
+ def test_combined_queries (self ):
180
+ global_framework = CausalTestingFramework (self .paths , query = "test_input > 0" )
181
+ global_framework .setup ()
182
+
183
+ with open (self .test_config_path , "r" , encoding = "utf-8" ) as f :
184
+ test_configs = json .load (f )
185
+
186
+ test_config = test_configs ["tests" ][0 ].copy ()
187
+ test_config ["query" ] = "test_output > 0"
188
+
189
+ base_test = global_framework .create_base_test (test_config )
190
+ causal_test = global_framework .create_causal_test (test_config , base_test )
191
+
192
+ self .assertTrue (len (causal_test .estimator .df ) > 0 )
193
+ self .assertTrue ((causal_test .estimator .df ["test_input" ] > 0 ).all ())
194
+ self .assertTrue ((causal_test .estimator .df ["test_output" ] > 0 ).all ())
195
+
147
196
def test_parse_args (self ):
148
197
with unittest .mock .patch (
149
198
"sys.argv" ,
0 commit comments