@@ -214,44 +214,63 @@ def __str__(self):
214
214
)
215
215
216
216
217
- def generate_metamorphic_relations (dag : CausalDAG ) -> list [MetamorphicRelation ]:
218
- """Construct a list of metamorphic relations implied by the Causal DAG.
217
+ def generate_metamorphic_relations (dag : CausalDAG , skip_ancestors : bool ) -> list [MetamorphicRelation ]:
218
+ """Construct a list of metamorphic relations based on the DAG or cyclic graph .
219
219
220
- This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
220
+ If is_causal_dag is True, this list contains a ShouldCause relation for every edge, and a ShouldNotCause
221
221
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
222
+ If is_causal_dag is False, it skips checks assuming the graph is acyclic and works on general graphs with loops/cycles.
222
223
223
- :param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
224
+ :param CausalDAG dag: Graph from which the metamorphic relations will be generated.
225
+ :param bool is_causal_dag: Specifies whether the input graph is a causal DAG or a cyclic graph.
224
226
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
225
227
"""
226
228
metamorphic_relations = []
229
+
227
230
for node_pair in combinations (dag .graph .nodes , 2 ):
228
231
(u , v ) = node_pair
229
232
230
- # Create a ShouldNotCause relation for each pair of nodes that are not directly connected
231
- if ((u , v ) not in dag .graph .edges ) and ((v , u ) not in dag .graph .edges ):
232
- # Case 1: U --> ... --> V
233
- if u in nx .ancestors (dag .graph , v ):
233
+ # If the graph is a causal DAG, perform the ancestor checks
234
+ if not skip_ancestors :
235
+ # Create a ShouldNotCause relation for each pair of nodes that are not directly connected
236
+ if ((u , v ) not in dag .graph .edges ) and ((v , u ) not in dag .graph .edges ):
237
+ # Case 1: U --> ... --> V
238
+ if u in nx .ancestors (dag .graph , v ):
239
+ adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
240
+ metamorphic_relations .append (ShouldNotCause (u , v , adj_set , dag ))
241
+
242
+ # Case 2: V --> ... --> U
243
+ elif v in nx .ancestors (dag .graph , u ):
244
+ adj_set = list (dag .direct_effect_adjustment_sets ([v ], [u ])[0 ])
245
+ metamorphic_relations .append (ShouldNotCause (v , u , adj_set , dag ))
246
+
247
+ # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
248
+ else :
249
+ adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
250
+ metamorphic_relations .append (ShouldNotCause (u , v , adj_set , dag ))
251
+
252
+ # Create a ShouldCause relation for each edge (u, v) or (v, u)
253
+ elif (u , v ) in dag .graph .edges :
234
254
adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
235
- metamorphic_relations .append (ShouldNotCause (u , v , adj_set , dag ))
236
-
237
- # Case 2: V --> ... --> U
238
- elif v in nx .ancestors (dag .graph , u ):
255
+ metamorphic_relations .append (ShouldCause (u , v , adj_set , dag ))
256
+ else :
239
257
adj_set = list (dag .direct_effect_adjustment_sets ([v ], [u ])[0 ])
240
- metamorphic_relations .append (ShouldNotCause (v , u , adj_set , dag ))
258
+ metamorphic_relations .append (ShouldCause (v , u , adj_set , dag ))
241
259
242
- # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
243
- # Only make one MR since V _||_ U == U _||_ V
244
- else :
245
- adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
260
+ # If the graph may contain loops/cycles, skip the ancestor checks
261
+ else :
262
+ # Create a ShouldNotCause relation for nodes that are not directly connected
263
+ if ((u , v ) not in dag .graph .edges ) and ((v , u ) not in dag .graph .edges ):
264
+ adj_set = [] # No adjustment sets needed for cyclic graphs
246
265
metamorphic_relations .append (ShouldNotCause (u , v , adj_set , dag ))
247
266
248
- # Create a ShouldCause relation for each edge (u, v) or (v, u)
249
- elif (u , v ) in dag .graph .edges :
250
- adj_set = list ( dag . direct_effect_adjustment_sets ([ u ], [ v ])[ 0 ])
251
- metamorphic_relations .append (ShouldCause (u , v , adj_set , dag ))
252
- else :
253
- adj_set = list ( dag . direct_effect_adjustment_sets ([ v ], [ u ])[ 0 ])
254
- metamorphic_relations .append (ShouldCause (v , u , adj_set , dag ))
267
+ # Create a ShouldCause relation for each edge (u, v) or (v, u)
268
+ elif (u , v ) in dag .graph .edges :
269
+ adj_set = [] # No adjustment sets needed for cyclic graphs
270
+ metamorphic_relations .append (ShouldCause (u , v , adj_set , dag ))
271
+ else :
272
+ adj_set = [] # No adjustment sets needed for cyclic graphs
273
+ metamorphic_relations .append (ShouldCause (v , u , adj_set , dag ))
255
274
256
275
return metamorphic_relations
257
276
@@ -273,10 +292,24 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
273
292
help = "Specify path where tests should be saved, normally a .json file." ,
274
293
required = True ,
275
294
)
295
+
296
+ parser .add_argument (
297
+ "--skip_ancestors" ,
298
+ "-c" ,
299
+ action = "store_true" ,
300
+ default = False ,
301
+ help = "Boolean flag to indicate if ancestors should be skipped. Default is False" ,
302
+ required = False ,
303
+ )
304
+
276
305
args = parser .parse_args ()
277
306
278
307
causal_dag = CausalDAG (args .dag_path )
279
- relations = generate_metamorphic_relations (causal_dag )
308
+ relations = generate_metamorphic_relations (causal_dag , skip_ancestors = args .skip_ancestors )
309
+
310
+ if args .skip_ancestors :
311
+ logger .warning ("The 'skip_ancestors' variable is set to True, proceed with caution." )
312
+
280
313
tests = [
281
314
relation .to_json_stub (skip = False )
282
315
for relation in relations
0 commit comments