|
9 | 9 | from causal_testing.testing.causal_test_outcome import ExactValue, Positive, Negative, NoEffect, CausalTestOutcome, \
|
10 | 10 | CausalTestResult
|
11 | 11 | from causal_testing.json_front.json_class import JsonUtility
|
| 12 | +from causal_testing.testing.estimators import Estimator |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class WidthHeightEstimator(LinearRegressionEstimator):
|
@@ -153,13 +154,13 @@ def get_args() -> argparse.Namespace:
|
153 | 154 | class MyJsonUtility(JsonUtility):
|
154 | 155 | """Extension of JsonUtility class to add modelling assumptions to the estimator instance"""
|
155 | 156 |
|
156 |
| - def add_modelling_assumptions(self, estimator: LinearRegressionEstimator): |
| 157 | + def add_modelling_assumptions(self, estimation_model: Estimator): |
157 | 158 | # Add squared intensity term as a modelling assumption if intensity is the treatment of the test
|
158 |
| - if "intensity" in estimator.treatment[0]: |
159 |
| - estimator.add_squared_term_to_df("intensity") |
160 |
| - if isinstance(estimator, WidthHeightEstimator): |
161 |
| - estimator.add_product_term_to_df("width", "intensity") |
162 |
| - estimator.add_product_term_to_df("height", "intensity") |
| 159 | + if "intensity" in estimation_model.treatment[0]: |
| 160 | + estimation_model.add_squared_term_to_df("intensity") |
| 161 | + if isinstance(estimation_model, WidthHeightEstimator): |
| 162 | + estimation_model.add_product_term_to_df("width", "intensity") |
| 163 | + estimation_model.add_product_term_to_df("height", "intensity") |
163 | 164 |
|
164 | 165 |
|
165 | 166 | if __name__ == "__main__":
|
|
0 commit comments