Skip to content

Commit e7e02b5

Browse files
author
AndrewC19
committed
Added test generation to base metamorphic relation class
1 parent 8026f79 commit e7e02b5

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed
Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,75 @@
11
from dataclasses import dataclass
22
from abc import abstractmethod
3+
from typing import Iterable
4+
from itertools import combinations
5+
import numpy as np
6+
import pandas as pd
37

8+
from causal_testing.specification.causal_specification import CausalDAG, Node
49

5-
@dataclass(order=True, frozen=True)
10+
@dataclass(order=True)
611
class MetamorphicRelation:
712
"""Class representing a metamorphic relation."""
13+
treatment_var: Node
14+
output_var: Node
15+
adjustment_vars: Iterable[Node]
16+
dag: CausalDAG
17+
tests: Iterable = None
818

9-
@abstractmethod
10-
def generate_follow_up(self, source_input_configuration):
11-
"""Generate a follow-up input configuration from a given source input
12-
configuration."""
13-
...
19+
def generate_follow_up(self,
20+
n_tests: int,
21+
min_val: float,
22+
max_val: float,
23+
seed: int = 0):
24+
"""Generate numerical follow-up input configurations."""
25+
np.random.seed(seed)
26+
27+
# Get set of variables to change, excluding the treatment itself
28+
variables_to_change = set([node for node in self.dag.graph.nodes if
29+
self.dag.graph.in_degree(node) == 0])
30+
if self.adjustment_vars:
31+
variables_to_change |= set(self.adjustment_vars)
32+
if self.treatment_var in variables_to_change:
33+
variables_to_change.remove(self.treatment_var)
34+
35+
# Assign random numerical values to the variables to change
36+
test_inputs = pd.DataFrame(
37+
np.random.randint(min_val, max_val,
38+
size=(n_tests, len(variables_to_change))
39+
),
40+
columns=sorted(variables_to_change)
41+
)
42+
43+
# Enumerate the possible source, follow-up pairs for the treatment
44+
candidate_source_follow_up_pairs = np.array(
45+
list(combinations(range(int(min_val), int(max_val+1)), 2))
46+
)
47+
48+
# Sample without replacement from the possible source, follow-up pairs
49+
sampled_source_follow_up_indices = np.random.choice(
50+
candidate_source_follow_up_pairs.shape[0], n_tests, replace=False
51+
)
52+
53+
follow_up_input = f"{self.treatment_var}\'"
54+
source_follow_up_test_inputs = pd.DataFrame(
55+
candidate_source_follow_up_pairs[sampled_source_follow_up_indices],
56+
columns=sorted([self.treatment_var] + [follow_up_input])
57+
)
58+
source_test_inputs = source_follow_up_test_inputs[[self.treatment_var]]
59+
follow_up_test_inputs = source_follow_up_test_inputs[[follow_up_input]]
60+
follow_up_test_inputs.rename({follow_up_input: self.treatment_var})
61+
62+
# TODO: Add a metamorphic test dataclass that stores these attributes
63+
self.tests = list(
64+
zip(
65+
source_test_inputs.to_dict(orient="records"),
66+
follow_up_test_inputs.to_dict(orient="records"),
67+
test_inputs.to_dict(orient="records") if not test_inputs.empty
68+
else [{}] * len(source_test_inputs),
69+
[self.output_var] * len(source_test_inputs),
70+
[str(self)] * len(source_test_inputs)
71+
)
72+
)
1473

1574
@abstractmethod
1675
def test_oracle(self):
@@ -23,16 +82,18 @@ def execute_test(self):
2382
...
2483

2584

26-
@dataclass(order=True, frozen=True)
85+
@dataclass(order=True)
2786
class ShouldCause(MetamorphicRelation):
2887
"""Class representing a should cause metamorphic relation."""
2988

30-
def generate_follow_up(self, source_input_configuration):
31-
pass
32-
3389
def test_oracle(self):
3490
pass
3591

3692
def execute_test(self):
3793
pass
3894

95+
def __str__(self):
96+
formatted_str = f"{self.treatment_var} --> {self.output_var}"
97+
if self.adjustment_vars:
98+
formatted_str += f" | {self.adjustment_vars}"
99+
return formatted_str
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import unittest
2+
import os
3+
4+
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
5+
from causal_testing.specification.causal_dag import CausalDAG
6+
from causal_testing.specification.metamorphic_relation import ShouldCause
7+
8+
class TestMetamorphicRelation(unittest.TestCase):
9+
10+
def setUp(self) -> None:
11+
temp_dir_path = create_temp_dir_if_non_existent()
12+
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
13+
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
14+
with open(self.dag_dot_path, "w") as f:
15+
f.write(dag_dot)
16+
17+
def test_metamorphic_relation(self):
18+
causal_dag = CausalDAG(self.dag_dot_path)
19+
for edge in causal_dag.graph.edges:
20+
(u, v) = edge
21+
should_cause_MR = ShouldCause(u, v, None, causal_dag)
22+
should_cause_MR.generate_follow_up(1, -10.0, 10.0, 1)
23+
print(should_cause_MR.tests)

0 commit comments

Comments
 (0)