@@ -139,8 +139,8 @@ def predict_qte(
139139
140140 qte_var = qtes .var (axis = 0 )
141141
142- qte_lower = qte + norm .ppf (alpha / 2 ) / np .sqrt (qte_var )
143- qte_upper = qte + norm .ppf (1 - alpha / 2 ) / np .sqrt (qte_var )
142+ qte_lower = qte + norm .ppf (alpha / 2 ) * np .sqrt (qte_var )
143+ qte_upper = qte + norm .ppf (1 - alpha / 2 ) * np .sqrt (qte_var )
144144
145145 return qte , qte_lower , qte_upper
146146
@@ -155,14 +155,14 @@ def _compute_dtes(
155155 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
156156 """Compute expected DTEs."""
157157 treatment_cdf , treatment_cdf_mat = self ._compute_cumulative_distribution (
158- np . full ( locations . shape , target_treatment_arm ) ,
158+ target_treatment_arm ,
159159 locations ,
160160 self .confoundings ,
161161 self .treatment_arms ,
162162 self .outcomes ,
163163 )
164164 control_cdf , control_cdf_mat = self ._compute_cumulative_distribution (
165- np . full ( locations . shape , control_treatment_arm ) ,
165+ control_treatment_arm ,
166166 locations ,
167167 self .confoundings ,
168168 self .treatment_arms ,
@@ -207,7 +207,7 @@ def _compute_ptes(
207207 """Compute expected PTEs."""
208208 treatment_cumulative_pre , treatment_cdf_mat_pre = (
209209 self ._compute_cumulative_distribution (
210- np . full ( locations . shape , target_treatment_arm ) ,
210+ target_treatment_arm ,
211211 locations ,
212212 self .confoundings ,
213213 self .treatment_arms ,
@@ -216,7 +216,7 @@ def _compute_ptes(
216216 )
217217 treatment_cumulative_post , treatment_cdf_mat_post = (
218218 self ._compute_cumulative_distribution (
219- np . full ( locations . shape , target_treatment_arm ) ,
219+ target_treatment_arm ,
220220 locations + width ,
221221 self .confoundings ,
222222 self .treatment_arms ,
@@ -226,7 +226,7 @@ def _compute_ptes(
226226 treatment_pdf = treatment_cumulative_post - treatment_cumulative_pre
227227 control_cumulative_pre , control_cdf_mat_pre = (
228228 self ._compute_cumulative_distribution (
229- np . full ( locations . shape , control_treatment_arm ) ,
229+ control_treatment_arm ,
230230 locations ,
231231 self .confoundings ,
232232 self .treatment_arms ,
@@ -235,7 +235,7 @@ def _compute_ptes(
235235 )
236236 control_cumulative_post , control_cdf_mat_post = (
237237 self ._compute_cumulative_distribution (
238- np . full ( locations . shape , control_treatment_arm ) ,
238+ control_treatment_arm ,
239239 locations + width ,
240240 self .confoundings ,
241241 self .treatment_arms ,
@@ -291,7 +291,7 @@ def find_quantile(quantile, arm):
291291 while low <= high :
292292 mid = (low + high ) // 2
293293 val , _ = self ._compute_cumulative_distribution (
294- np . full (( 1 ), arm ) ,
294+ arm ,
295295 np .full ((1 ), locations [mid ]),
296296 confoundings ,
297297 treatment_arms ,
@@ -339,11 +339,11 @@ def fit(
339339
340340 return self
341341
342- def predict (self , treatment_arms : np . ndarray , locations : np .ndarray ) -> np .ndarray :
342+ def predict (self , treatment_arm : int , locations : np .ndarray ) -> np .ndarray :
343343 """Compute cumulative distribution values.
344344
345345 Args:
346- treatment_arms (np.ndarray ): The index of the treatment arm.
346+ treatment_arm (int ): The index of the treatment arm.
347347 outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
348348
349349 Returns:
@@ -354,15 +354,13 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
354354 "This estimator has not been trained yet. Please call fit first"
355355 )
356356
357- unincluded_arms = set (treatment_arms ) - set (self .treatment_arms )
358-
359- if len (unincluded_arms ) > 0 :
357+ if treatment_arm not in self .treatment_arms :
360358 raise ValueError (
361- f"This treatment_arms argument contains arms not included in the training data: { unincluded_arms } "
359+ f"This target treatment arm was not included in the training data: { treatment_arm } "
362360 )
363361
364362 return self ._compute_cumulative_distribution (
365- treatment_arms ,
363+ treatment_arm ,
366364 locations ,
367365 self .confoundings ,
368366 self .treatment_arms ,
@@ -371,7 +369,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
371369
372370 def _compute_cumulative_distribution (
373371 self ,
374- target_treatment_arms : np . ndarray ,
372+ target_treatment_arm : int ,
375373 locations : np .ndarray ,
376374 confoundings : np .ndarray ,
377375 treatment_arms : np .ndarray ,
@@ -396,7 +394,7 @@ def __init__(self):
396394
397395 def _compute_cumulative_distribution (
398396 self ,
399- target_treatment_arms : np . ndarray ,
397+ target_treatment_arm : int ,
400398 locations : np .ndarray ,
401399 confoundings : np .ndarray ,
402400 treatment_arms : np .ndarray ,
@@ -405,7 +403,7 @@ def _compute_cumulative_distribution(
405403 """Compute the cumulative distribution values.
406404
407405 Args:
408- target_treatment_arms (np.ndarray ): The index of the treatment arm.
406+ target_treatment_arm (int ): The index of the treatment arm.
409407 locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
410408 confoundings: (np.ndarray): An array of confounding variables in the observed data.
411409 treatment_arms (np.ndarray): An array of treatment arms in the observed data.
@@ -426,10 +424,10 @@ def _compute_cumulative_distribution(
426424 d_confounding [arm ] = selected_confounding [sorted_indices ]
427425 d_outcome [arm ] = selected_outcome [sorted_indices ]
428426 cumulative_distribution = np .zeros (locations .shape )
429- for i , ( outcome , arm ) in enumerate (zip ( locations , target_treatment_arms ) ):
427+ for i , outcome in enumerate (locations ):
430428 cumulative_distribution [i ] = (
431- np .searchsorted (d_outcome [arm ], outcome , side = "right" )
432- ) / d_outcome [arm ]. shape [ 0 ]
429+ np .searchsorted (d_outcome [target_treatment_arm ], outcome , side = "right" )
430+ ) / len ( d_outcome [target_treatment_arm ])
433431 return cumulative_distribution , np .zeros ((n_obs , n_loc ))
434432
435433
@@ -460,7 +458,7 @@ def __init__(self, base_model, folds=3, is_multi_task=False):
460458
461459 def _compute_cumulative_distribution (
462460 self ,
463- target_treatment_arms : np . ndarray ,
461+ target_treatment_arm : int ,
464462 locations : np .ndarray ,
465463 confoundings : np .ndarray ,
466464 treatment_arms : np .ndarray ,
@@ -469,7 +467,7 @@ def _compute_cumulative_distribution(
469467 """Compute the cumulative distribution values.
470468
471469 Args:
472- target_treatment_arms (np.ndarray ): The index of the treatment arm.
470+ target_treatment_arm (int ): The index of the treatment arm.
473471 locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
474472 confoundings: (np.ndarray): An array of confounding variables in the observed data.
475473 treatment_arm (np.ndarray): An array of treatment arms in the observed data.
@@ -482,8 +480,7 @@ def _compute_cumulative_distribution(
482480 n_loc = locations .shape [0 ]
483481 cumulative_distribution = np .zeros (n_loc )
484482 superset_prediction = np .zeros ((n_records , n_loc ))
485- arm = target_treatment_arms [0 ]
486- treatment_mask = treatment_arms == arm
483+ treatment_mask = treatment_arms == target_treatment_arm
487484 if self .is_multi_task :
488485 confounding_in_arm = confoundings [treatment_mask ]
489486 n_records_in_arm = len (confounding_in_arm )
@@ -512,7 +509,7 @@ def _compute_cumulative_distribution(
512509 cdf - subset_prediction .mean (axis = 0 ) + superset_prediction .mean (axis = 0 )
513510 ) # (n_loc)
514511 else :
515- for i , ( location , arm ) in enumerate (zip ( locations , target_treatment_arms ) ):
512+ for i , location in enumerate (locations ):
516513 confounding_in_arm = confoundings [treatment_mask ]
517514 outcome_in_arm = outcomes [treatment_mask ]
518515 subset_prediction = np .zeros (outcome_in_arm .shape [0 ])
@@ -525,7 +522,7 @@ def _compute_cumulative_distribution(
525522 confounding_train = confoundings [~ subset_mask ]
526523 confounding_fit = confoundings [subset_mask ]
527524 binominal_train = binominal [~ subset_mask ]
528- if np .unique (binominal_train ). shape [ 0 ] == 1 :
525+ if len ( np .unique (binominal_train )) == 1 :
529526 subset_prediction [subset_mask_inner ] = binominal_train [0 ]
530527 superset_prediction [superset_mask , i ] = binominal_train [0 ]
531528 continue
0 commit comments