|
7 | 7 | from abc import abstractmethod
|
8 | 8 | from typing import Iterable
|
9 | 9 | from itertools import combinations
|
10 |
| -import numpy as np |
11 |
| -import pandas as pd |
| 10 | +import argparse |
| 11 | +import logging |
| 12 | +import json |
12 | 13 | import networkx as nx
|
| 14 | +import pandas as pd |
| 15 | +import numpy as np |
13 | 16 |
|
14 | 17 | from causal_testing.specification.causal_specification import CausalDAG, Node
|
15 | 18 | from causal_testing.data_collection.data_collector import ExperimentalDataCollector
|
16 | 19 |
|
| 20 | +logger = logging.getLogger(__name__) |
| 21 | + |
17 | 22 |
|
18 | 23 | @dataclass(order=True)
|
19 | 24 | class MetamorphicRelation:
|
@@ -246,3 +251,35 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
|
246 | 251 | metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
|
247 | 252 |
|
248 | 253 | return metamorphic_relations
|
| 254 | + |
| 255 | + |
| 256 | +if __name__ == "__main__": # pragma: no cover |
| 257 | + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO) |
| 258 | + parser = argparse.ArgumentParser( |
| 259 | + description="A script for generating metamorphic relations to test the causal relationships in a given DAG." |
| 260 | + ) |
| 261 | + parser.add_argument( |
| 262 | + "--dag_path", |
| 263 | + "-d", |
| 264 | + help="Specify path to file containing the DAG, normally a .dot file.", |
| 265 | + required=True, |
| 266 | + ) |
| 267 | + parser.add_argument( |
| 268 | + "--output_path", |
| 269 | + "-o", |
| 270 | + help="Specify path where tests should be saved, normally a .json file.", |
| 271 | + required=True, |
| 272 | + ) |
| 273 | + args = parser.parse_args() |
| 274 | + |
| 275 | + causal_dag = CausalDAG(args.dag_path) |
| 276 | + relations = generate_metamorphic_relations(causal_dag) |
| 277 | + tests = [ |
| 278 | + relation.to_json_stub(skip=False) |
| 279 | + for relation in relations |
| 280 | + if len(list(causal_dag.graph.predecessors(relation.output_var))) > 0 |
| 281 | + ] |
| 282 | + |
| 283 | + logger.info(f"Generated {len(tests)} tests. Saving to {args.output_path}.") |
| 284 | + with open(args.output_path, "w", encoding="utf-8") as f: |
| 285 | + json.dump({"tests": tests}, f, indent=2) |
0 commit comments