26
26
logger = logging .getLogger (__name__ )
27
27
28
28
29
- class JsonUtility ( ABC ) :
29
+ class JsonUtility :
30
30
"""
31
31
The JsonUtility Class provides the functionality to use structured JSON to setup and run causal tests on the
32
32
CausalTestingFramework.
@@ -48,7 +48,7 @@ def __init__(self, log_path):
48
48
self .variables = None
49
49
self .data = []
50
50
self .test_plan = None
51
- self .modelling_scenario = None
51
+ self .scenario = None
52
52
self .causal_specification = None
53
53
self .setup_logger (log_path )
54
54
@@ -61,36 +61,27 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
61
61
"""
62
62
self .paths = JsonClassPaths (json_path = json_path , dag_path = dag_path , data_paths = data_paths )
63
63
64
- def set_variables (self , inputs : list [dict ], outputs : list [dict ], metas : list [dict ]):
65
- """Populate the Causal Variables
66
- :param inputs:
67
- :param outputs:
68
- :param metas:
69
- """
70
-
71
- self .variables = CausalVariables (inputs = inputs , outputs = outputs , metas = metas )
72
-
73
- def setup (self ):
64
+ def setup (self , scenario : Scenario ):
74
65
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
75
- self .modelling_scenario = Scenario ( self . variables . inputs + self . variables . outputs + self . variables . metas , None )
76
- self .modelling_scenario .setup_treatment_variables ()
66
+ self .scenario = scenario
67
+ self .scenario .setup_treatment_variables ()
77
68
self .causal_specification = CausalSpecification (
78
- scenario = self .modelling_scenario , causal_dag = CausalDAG (self .paths .dag_path )
69
+ scenario = self .scenario , causal_dag = CausalDAG (self .paths .dag_path )
79
70
)
80
71
self ._json_parse ()
81
72
self ._populate_metas ()
82
73
83
74
def _create_abstract_test_case (self , test , mutates , effects ):
84
75
assert len (test ["mutations" ]) == 1
85
76
abstract_test = AbstractCausalTestCase (
86
- scenario = self .modelling_scenario ,
77
+ scenario = self .scenario ,
87
78
intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
88
- treatment_variable = next (self .modelling_scenario .variables [v ] for v in test ["mutations" ]),
79
+ treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
89
80
expected_causal_effect = {
90
- self .modelling_scenario .variables [variable ]: effects [effect ]
81
+ self .scenario .variables [variable ]: effects [effect ]
91
82
for variable , effect in test ["expectedEffect" ].items ()
92
83
},
93
- effect_modifiers = {self .modelling_scenario .variables [v ] for v in test ["effect_modifiers" ]}
84
+ effect_modifiers = {self .scenario .variables [v ] for v in test ["effect_modifiers" ]}
94
85
if "effect_modifiers" in test
95
86
else {},
96
87
estimate_type = test ["estimate_type" ],
@@ -141,10 +132,9 @@ def _populate_metas(self):
141
132
"""
142
133
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
143
134
"""
144
- for meta in self .variables . metas :
135
+ for meta in self .scenario . variables_of_type ( Meta ) :
145
136
meta .populate (self .data )
146
-
147
- for var in self .variables .metas + self .variables .outputs :
137
+ for var in self .scenario .variables_of_type (Meta ).union (self .scenario .variables_of_type (Output )):
148
138
if not var .distribution :
149
139
fitter = Fitter (self .data [var .name ], distributions = get_common_distributions ())
150
140
fitter .fit ()
@@ -195,7 +185,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
195
185
- estimation_model - Estimator instance for the test being run
196
186
"""
197
187
198
- data_collector = ObservationalDataCollector (self .modelling_scenario , self .data )
188
+ data_collector = ObservationalDataCollector (self .scenario , self .data )
199
189
causal_test_engine = CausalTestEngine (self .causal_specification , data_collector , index_col = 0 )
200
190
201
191
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
@@ -289,17 +279,16 @@ def __init__(self, json_path: str, dag_path: str, data_paths: str):
289
279
self .data_paths = [Path (path ) for path in data_paths ]
290
280
291
281
292
- @dataclass ()
293
282
class CausalVariables :
294
283
"""
295
284
A dataclass that converts
296
285
"""
297
286
298
- inputs : list [Input ]
299
- outputs : list [Output ]
300
- metas : list [Meta ]
301
-
302
287
def __init__ (self , inputs : list [dict ], outputs : list [dict ], metas : list [dict ]):
303
288
self .inputs = [Input (** i ) for i in inputs ]
304
289
self .outputs = [Output (** o ) for o in outputs ]
305
290
self .metas = [Meta (** m ) for m in metas ] if metas else []
291
+
292
+ def __iter__ (self ):
293
+ for var in self .inputs + self .outputs + self .metas :
294
+ yield var
0 commit comments