3
3
import shutil , tempfile
4
4
import pandas as pd
5
5
from itertools import combinations
6
+ import tempfile
7
+ import json
6
8
7
9
from causal_testing .specification .causal_dag import CausalDAG
8
10
from causal_testing .specification .causal_specification import Scenario
11
13
ShouldNotCause ,
12
14
generate_metamorphic_relations ,
13
15
generate_metamorphic_relation ,
16
+ generate_causal_tests ,
14
17
)
15
18
from causal_testing .specification .variable import Input , Output
16
19
from causal_testing .testing .base_test_case import BaseTestCase
@@ -177,8 +180,8 @@ def test_all_metamorphic_relations_implied_by_dag_parallel(self):
177
180
self .assertEqual (missing_snc_relations , [])
178
181
179
182
def test_all_metamorphic_relations_implied_by_dag_ignore_cycles (self ):
180
- dag = CausalDAG (self .dcg_dot_path , ignore_cycles = True )
181
- metamorphic_relations = generate_metamorphic_relations (dag , threads = 2 , nodes_to_ignore = set (dag .cycle_nodes ()))
183
+ dcg = CausalDAG (self .dcg_dot_path , ignore_cycles = True )
184
+ metamorphic_relations = generate_metamorphic_relations (dcg , threads = 2 , nodes_to_ignore = set (dcg .cycle_nodes ()))
182
185
should_cause_relations = [mr for mr in metamorphic_relations if isinstance (mr , ShouldCause )]
183
186
should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance (mr , ShouldNotCause )]
184
187
@@ -203,6 +206,46 @@ def test_generate_metamorphic_relation_(self):
203
206
ShouldCause (BaseTestCase ("X1" , "Z" ), []),
204
207
)
205
208
209
+ def test_generate_causal_tests_ignore_cycles (self ):
210
+ dcg = CausalDAG (self .dcg_dot_path , ignore_cycles = True )
211
+ relations = generate_metamorphic_relations (dcg , nodes_to_ignore = set (dcg .cycle_nodes ()))
212
+ with tempfile .TemporaryDirectory () as tmp :
213
+ tests_file = os .path .join (tmp , "causal_tests.json" )
214
+ generate_causal_tests (self .dcg_dot_path , tests_file , ignore_cycles = True )
215
+ with open (tests_file , encoding = "utf8" ) as f :
216
+ tests = json .load (f )
217
+ expected = list (
218
+ map (
219
+ lambda x : x .to_json_stub (skip = False ),
220
+ filter (
221
+ lambda relation : len (list (dcg .graph .predecessors (relation .base_test_case .outcome_variable )))
222
+ > 0 ,
223
+ relations ,
224
+ ),
225
+ )
226
+ )
227
+ self .assertEqual (tests ["tests" ], expected )
228
+
229
+ def test_generate_causal_tests (self ):
230
+ dag = CausalDAG (self .dag_dot_path )
231
+ relations = generate_metamorphic_relations (dag )
232
+ with tempfile .TemporaryDirectory () as tmp :
233
+ tests_file = os .path .join (tmp , "causal_tests.json" )
234
+ generate_causal_tests (self .dag_dot_path , tests_file )
235
+ with open (tests_file , encoding = "utf8" ) as f :
236
+ tests = json .load (f )
237
+ expected = list (
238
+ map (
239
+ lambda x : x .to_json_stub (skip = False ),
240
+ filter (
241
+ lambda relation : len (list (dag .graph .predecessors (relation .base_test_case .outcome_variable )))
242
+ > 0 ,
243
+ relations ,
244
+ ),
245
+ )
246
+ )
247
+ self .assertEqual (tests ["tests" ], expected )
248
+
206
249
def test_shoud_cause_string (self ):
207
250
sc_mr = ShouldCause (BaseTestCase ("X" , "Y" ), ["A" , "B" , "C" ])
208
251
self .assertEqual (str (sc_mr ), "X --> Y | ['A', 'B', 'C']" )
0 commit comments