Skip to content

Commit bf8b67a

Browse files
Move mutates dictionary
1 parent 0900f47 commit bf8b67a

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

examples/poisson/causal_test_setup.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
CausalTestResult
1111
from causal_testing.json_front.json_class import JsonUtility
1212
from causal_testing.testing.estimators import Estimator
13-
13+
from causal_testing.specification.scenario import Scenario
14+
from causal_testing.specification.variable import Input, Output, Meta
1415

1516
class WidthHeightEstimator(LinearRegressionEstimator):
1617
"""
@@ -113,11 +114,12 @@ def get_args() -> argparse.Namespace:
113114
{"name": "height", "type": float, "distribution": "uniform"},
114115
{"name": "intensity", "type": float, "distribution": "uniform"}
115116
]
117+
116118
outputs = [
117119
{"name": "num_lines_abs", "type": float},
118120
{"name": "num_shapes_abs", "type": float}
119-
120121
]
122+
121123
metas = [
122124
{"name": "num_lines_unit", "type": float, "populate": "populate_num_lines_unit"},
123125
{"name": "num_shapes_unit", "type": float, "populate": "populate_num_shapes_unit"},
@@ -151,6 +153,24 @@ def get_args() -> argparse.Namespace:
151153
}
152154

153155

156+
# Create input structure required to create a modelling scenario
157+
modelling_inputs = [Input(i['name'], i['type'], distributions[i['distribution']]) for i in inputs] +\
158+
[Output(i['name'], i['type']) for i in outputs] +\
159+
[Meta(i['name'], i['type'], populates[i['populate']]) for i in metas] if metas else list()
160+
161+
# Create modelling scenario to access z3 variable mirrors
162+
modelling_scenario = Scenario(modelling_inputs, None)
163+
modelling_scenario.setup_treatment_variables()
164+
165+
mutates = {
166+
"Increase": lambda x: modelling_scenario.treatment_variables[x].z3 >
167+
modelling_scenario.variables[x].z3,
168+
"ChangeByFactor(2)": lambda x: modelling_scenario.treatment_variables[x].z3 ==
169+
modelling_scenario.variables[
170+
x].z3 * 2
171+
}
172+
173+
154174
class MyJsonUtility(JsonUtility):
155175
"""Extension of JsonUtility class to add modelling assumptions to the estimator instance"""
156176

@@ -167,18 +187,10 @@ def add_modelling_assumptions(self, estimation_model: Estimator):
167187
args = get_args()
168188

169189
json_utility = MyJsonUtility() # Create an instance of the extended JsonUtility class
170-
json_utility.set_path(args.directory_path) # Set the path to the data.csv, dag.dot and causal_tests.json file
190+
json_utility.set_path(args.directory_path) # Set the path to the data.csv, dag.dot and causal_tests.json file
171191

172192
# Load the Causal Variables into the JsonUtility class ready to be used in the tests
173193
json_utility.set_variables(inputs, outputs, metas, distributions, populates)
174194
json_utility.setup() # Sets up all the necessary parts of the json_class needed to execute tests
175195

176-
mutates = {
177-
"Increase": lambda x: json_utility.modelling_scenario.treatment_variables[x].z3 >
178-
json_utility.modelling_scenario.variables[x].z3,
179-
"ChangeByFactor(2)": lambda x: json_utility.modelling_scenario.treatment_variables[x].z3 ==
180-
json_utility.modelling_scenario.variables[
181-
x].z3 * 2
182-
}
183-
184196
json_utility.execute_tests(effects, mutates, estimators, args.f)

0 commit comments

Comments
 (0)