1
1
from dataclasses import dataclass
2
2
from abc import abstractmethod
3
+ from typing import Iterable
4
+ from itertools import combinations
5
+ import numpy as np
6
+ import pandas as pd
3
7
8
+ from causal_testing .specification .causal_specification import CausalDAG , Node
4
9
5
- @dataclass (order = True , frozen = True )
10
+ @dataclass (order = True )
6
11
class MetamorphicRelation :
7
12
"""Class representing a metamorphic relation."""
13
+ treatment_var : Node
14
+ output_var : Node
15
+ adjustment_vars : Iterable [Node ]
16
+ dag : CausalDAG
17
+ tests : Iterable = None
8
18
9
- @abstractmethod
10
- def generate_follow_up (self , source_input_configuration ):
11
- """Generate a follow-up input configuration from a given source input
12
- configuration."""
13
- ...
19
+ def generate_follow_up (self ,
20
+ n_tests : int ,
21
+ min_val : float ,
22
+ max_val : float ,
23
+ seed : int = 0 ):
24
+ """Generate numerical follow-up input configurations."""
25
+ np .random .seed (seed )
26
+
27
+ # Get set of variables to change, excluding the treatment itself
28
+ variables_to_change = set ([node for node in self .dag .graph .nodes if
29
+ self .dag .graph .in_degree (node ) == 0 ])
30
+ if self .adjustment_vars :
31
+ variables_to_change |= set (self .adjustment_vars )
32
+ if self .treatment_var in variables_to_change :
33
+ variables_to_change .remove (self .treatment_var )
34
+
35
+ # Assign random numerical values to the variables to change
36
+ test_inputs = pd .DataFrame (
37
+ np .random .randint (min_val , max_val ,
38
+ size = (n_tests , len (variables_to_change ))
39
+ ),
40
+ columns = sorted (variables_to_change )
41
+ )
42
+
43
+ # Enumerate the possible source, follow-up pairs for the treatment
44
+ candidate_source_follow_up_pairs = np .array (
45
+ list (combinations (range (int (min_val ), int (max_val + 1 )), 2 ))
46
+ )
47
+
48
+ # Sample without replacement from the possible source, follow-up pairs
49
+ sampled_source_follow_up_indices = np .random .choice (
50
+ candidate_source_follow_up_pairs .shape [0 ], n_tests , replace = False
51
+ )
52
+
53
+ follow_up_input = f"{ self .treatment_var } \' "
54
+ source_follow_up_test_inputs = pd .DataFrame (
55
+ candidate_source_follow_up_pairs [sampled_source_follow_up_indices ],
56
+ columns = sorted ([self .treatment_var ] + [follow_up_input ])
57
+ )
58
+ source_test_inputs = source_follow_up_test_inputs [[self .treatment_var ]]
59
+ follow_up_test_inputs = source_follow_up_test_inputs [[follow_up_input ]]
60
+ follow_up_test_inputs .rename ({follow_up_input : self .treatment_var })
61
+
62
+ # TODO: Add a metamorphic test dataclass that stores these attributes
63
+ self .tests = list (
64
+ zip (
65
+ source_test_inputs .to_dict (orient = "records" ),
66
+ follow_up_test_inputs .to_dict (orient = "records" ),
67
+ test_inputs .to_dict (orient = "records" ) if not test_inputs .empty
68
+ else [{}] * len (source_test_inputs ),
69
+ [self .output_var ] * len (source_test_inputs ),
70
+ [str (self )] * len (source_test_inputs )
71
+ )
72
+ )
14
73
15
74
@abstractmethod
16
75
def test_oracle (self ):
@@ -23,16 +82,18 @@ def execute_test(self):
23
82
...
24
83
25
84
26
- @dataclass (order = True , frozen = True )
85
+ @dataclass (order = True )
27
86
class ShouldCause (MetamorphicRelation ):
28
87
"""Class representing a should cause metamorphic relation."""
29
88
30
- def generate_follow_up (self , source_input_configuration ):
31
- pass
32
-
33
89
def test_oracle (self ):
34
90
pass
35
91
36
92
def execute_test (self ):
37
93
pass
38
94
95
+ def __str__ (self ):
96
+ formatted_str = f"{ self .treatment_var } --> { self .output_var } "
97
+ if self .adjustment_vars :
98
+ formatted_str += f" | { self .adjustment_vars } "
99
+ return formatted_str
0 commit comments