@@ -79,14 +79,20 @@ class LinearRegressionEstimator(Estimator):
79
79
combination of parameters and functions of the variables (note these functions need not be linear).
80
80
"""
81
81
def __init__ (self , treatment : tuple , treatment_values : float , control_values : float , adjustment_set : set ,
82
- outcome : tuple , df : pd .DataFrame = None , effect_modifiers : dict [Variable : Any ] = None , product_terms : list [tuple [Variable , Variable ]] = None ):
82
+ outcome : tuple , df : pd .DataFrame = None , effect_modifiers : dict [Variable : Any ] = None , product_terms : list [tuple [Variable , Variable ]] = None , intercept : int = 1 ):
83
83
super ().__init__ (treatment , treatment_values , control_values , adjustment_set , outcome , df , effect_modifiers )
84
+
84
85
if product_terms is None :
85
86
product_terms = []
86
87
for (term_a , term_b ) in product_terms :
87
88
self .add_product_term_to_df (term_a , term_b )
89
+ for term in self .effect_modifiers :
90
+ self .adjustment_set .add (term )
91
+
92
+ self .product_terms = product_terms
88
93
self .square_terms = []
89
- self .product_terms = []
94
+ self .inverse_terms = []
95
+ self .intercept = intercept
90
96
91
97
def add_modelling_assumptions (self ):
92
98
"""
@@ -112,6 +118,21 @@ def add_squared_term_to_df(self, term_to_square: str):
112
118
f'with { term_to_square } .'
113
119
self .square_terms .append (term_to_square )
114
120
121
+ def add_inverse_term_to_df (self , term_to_invert : str ):
122
+ """ Add an inverse term to the linear regression model and df.
123
+
124
+ This enables the user to capture curvilinear relationships with a linear regression model, not just straight
125
+ lines, while automatically adding the modelling assumption imposed by the addition of this term.
126
+
127
+ :param term_to_square: The term (column in data and variable in DAG) which is to be squared.
128
+ """
129
+ new_term = "1/" + str (term_to_invert )
130
+ self .df [new_term ] = 1 / self .df [term_to_invert ]
131
+ self .adjustment_set .add (new_term )
132
+ self .modelling_assumptions += f'Relationship between { self .treatment } and { self .outcome } varies inversely' \
133
+ f'with { term_to_invert } .'
134
+ self .inverse_terms .append (term_to_invert )
135
+
115
136
def add_product_term_to_df (self , term_a : str , term_b : str ):
116
137
""" Add a product term to the linear regression model and df.
117
138
@@ -146,6 +167,7 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
146
167
:return: The average treatment effect and the 95% Wald confidence intervals.
147
168
"""
148
169
model = self ._run_linear_regression ()
170
+ print (model .summary ())
149
171
# Create an empty individual for the control and treated
150
172
individuals = pd .DataFrame (1 , index = ['control' , 'treated' ], columns = model .params .index )
151
173
individuals .loc ['control' , list (self .treatment )] = self .control_values
@@ -162,30 +184,57 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
162
184
confidence_intervals = list (t_test_results .conf_int ().flatten ())
163
185
return ate , confidence_intervals
164
186
165
- def estimate_risk_ratio (self ) -> tuple [float , list [float , float ]]:
166
- """ Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
167
- by changing the treatment variable from the control value to the treatment value.
187
+ def estimate_control_treatment (self ) -> tuple [pd .Series , pd .Series ]:
188
+ """ Estimate the outcomes under control and treatment.
168
189
169
190
:return: The average treatment effect and the 95% Wald confidence intervals.
170
191
"""
171
192
model = self ._run_linear_regression ()
193
+ self .model = model
172
194
173
195
x = pd .DataFrame ()
174
196
x [self .treatment [0 ]] = [self .treatment_values , self .control_values ]
175
- x ['Intercept' ] = 1
197
+ x ['Intercept' ] = self .intercept
198
+ for k , v in self .effect_modifiers .items ():
199
+ x [k ] = v
176
200
for t in self .square_terms :
177
201
x [t + '^2' ] = x [t ] ** 2
202
+ for t in self .inverse_terms :
203
+ x ['1/' + t ] = 1 / x [t ]
178
204
for a , b in self .product_terms :
179
205
x [f"{ a } *{ b } " ] = x [a ] * x [b ]
206
+ x = x [model .params .index ]
180
207
181
- print ( x )
182
- print ( model . summary ())
208
+ y = model . get_prediction ( x ). summary_frame ( )
209
+ return y . iloc [ 1 ], y . iloc [ 0 ]
183
210
184
- y = model .predict (x )
185
- treatment_outcome = y .iloc [0 ]
186
- control_outcome = y .iloc [1 ]
187
211
188
- return treatment_outcome / control_outcome , None
212
+ def estimate_risk_ratio (self ) -> tuple [float , list [float , float ]]:
213
+ """ Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
214
+ by changing the treatment variable from the control value to the treatment value.
215
+
216
+ :return: The average treatment effect and the 95% Wald confidence intervals.
217
+ """
218
+ control_outcome , treatment_outcome = self .estimate_control_treatment ()
219
+ ci_low = treatment_outcome ['mean_ci_lower' ] / control_outcome ['mean_ci_upper' ]
220
+ ci_high = treatment_outcome ['mean_ci_upper' ] / control_outcome ['mean_ci_lower' ]
221
+
222
+ return (treatment_outcome ['mean' ]/ control_outcome ['mean' ]), [ci_low , ci_high ]
223
+
224
+
225
+ def estimate_ate_calculated (self ) -> tuple [float , list [float , float ]]:
226
+ """ Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
227
+ by changing the treatment variable from the control value to the treatment value. Here, we actually
228
+ calculate the expected outcomes under control and treatment and take one away from the other. This
229
+ allows for custom terms to be put in such as squares, inverses, products, etc.
230
+
231
+ :return: The average treatment effect and the 95% Wald confidence intervals.
232
+ """
233
+ control_outcome , treatment_outcome = self .estimate_control_treatment ()
234
+ ci_low = treatment_outcome ['mean_ci_lower' ] - control_outcome ['mean_ci_upper' ]
235
+ ci_high = treatment_outcome ['mean_ci_upper' ] - control_outcome ['mean_ci_lower' ]
236
+
237
+ return (treatment_outcome ['mean' ]- control_outcome ['mean' ]), [ci_low , ci_high ]
189
238
190
239
def estimate_cates (self ) -> tuple [float , list [float , float ]]:
191
240
""" Estimate the conditional average treatment effect of the treatment on the outcome. That is, the change
@@ -196,7 +245,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
196
245
assert self .effect_modifiers , f"Must have at least one effect modifier to compute CATE - { self .effect_modifiers } ."
197
246
x = pd .DataFrame ()
198
247
x [self .treatment [0 ]] = [self .treatment_values , self .control_values ]
199
- x ['Intercept' ] = 1
248
+ x ['Intercept' ] = self . intercept
200
249
for k , v in self .effect_modifiers .items ():
201
250
self .adjustment_set .add (k )
202
251
x [k ] = v
@@ -226,11 +275,11 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
226
275
necessary_cols = list (self .treatment ) + list (self .adjustment_set ) + list (self .outcome )
227
276
missing_rows = reduced_df [necessary_cols ].isnull ().any (axis = 1 )
228
277
reduced_df = reduced_df [~ missing_rows ]
278
+ reduced_df = reduced_df .sort_values (list (self .treatment ))
229
279
logger .debug (reduced_df [necessary_cols ])
230
280
231
281
# 2. Add intercept
232
- reduced_df ['Intercept' ] = 1
233
-
282
+ reduced_df ['Intercept' ] = self .intercept
234
283
235
284
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
236
285
cols = list (self .treatment )
@@ -289,6 +338,7 @@ def estimate_ate(self) -> float:
289
338
model .fit (outcome_df , treatment_df , X = effect_modifier_df , W = confounders_df )
290
339
291
340
# Obtain the ATE and 95% confidence intervals
341
+ print (dir (model ))
292
342
ate = model .ate (effect_modifier_df , T0 = self .control_values , T1 = self .treatment_values )
293
343
ate_interval = model .ate_interval (effect_modifier_df , T0 = self .control_values , T1 = self .treatment_values )
294
344
ci_low , ci_high = ate_interval [0 ], ate_interval [1 ]
0 commit comments