@@ -139,8 +139,8 @@ def predict_qte(
139139
140140 qte_var = qtes .var (axis = 0 )
141141
142- qte_lower = qte + norm .ppf (alpha / 2 ) / np .sqrt (qte_var )
143- qte_upper = qte + norm .ppf (1 - alpha / 2 ) / np .sqrt (qte_var )
142+ qte_lower = qte + norm .ppf (alpha / 2 ) * np .sqrt (qte_var )
143+ qte_upper = qte + norm .ppf (1 - alpha / 2 ) * np .sqrt (qte_var )
144144
145145 return qte , qte_lower , qte_upper
146146
@@ -155,14 +155,14 @@ def _compute_dtes(
155155 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
156156 """Compute expected DTEs."""
157157 treatment_cdf , treatment_cdf_mat = self ._compute_cumulative_distribution (
158- np . full ( locations . shape , target_treatment_arm ) ,
158+ target_treatment_arm ,
159159 locations ,
160160 self .confoundings ,
161161 self .treatment_arms ,
162162 self .outcomes ,
163163 )
164164 control_cdf , control_cdf_mat = self ._compute_cumulative_distribution (
165- np . full ( locations . shape , control_treatment_arm ) ,
165+ control_treatment_arm ,
166166 locations ,
167167 self .confoundings ,
168168 self .treatment_arms ,
@@ -207,7 +207,7 @@ def _compute_ptes(
207207 """Compute expected PTEs."""
208208 treatment_cumulative_pre , treatment_cdf_mat_pre = (
209209 self ._compute_cumulative_distribution (
210- np . full ( locations . shape , target_treatment_arm ) ,
210+ target_treatment_arm ,
211211 locations ,
212212 self .confoundings ,
213213 self .treatment_arms ,
@@ -216,7 +216,7 @@ def _compute_ptes(
216216 )
217217 treatment_cumulative_post , treatment_cdf_mat_post = (
218218 self ._compute_cumulative_distribution (
219- np . full ( locations . shape , target_treatment_arm ) ,
219+ target_treatment_arm ,
220220 locations + width ,
221221 self .confoundings ,
222222 self .treatment_arms ,
@@ -226,7 +226,7 @@ def _compute_ptes(
226226 treatment_pdf = treatment_cumulative_post - treatment_cumulative_pre
227227 control_cumulative_pre , control_cdf_mat_pre = (
228228 self ._compute_cumulative_distribution (
229- np . full ( locations . shape , control_treatment_arm ) ,
229+ control_treatment_arm ,
230230 locations ,
231231 self .confoundings ,
232232 self .treatment_arms ,
@@ -235,7 +235,7 @@ def _compute_ptes(
235235 )
236236 control_cumulative_post , control_cdf_mat_post = (
237237 self ._compute_cumulative_distribution (
238- np . full ( locations . shape , control_treatment_arm ) ,
238+ control_treatment_arm ,
239239 locations + width ,
240240 self .confoundings ,
241241 self .treatment_arms ,
@@ -291,7 +291,7 @@ def find_quantile(quantile, arm):
291291 while low <= high :
292292 mid = (low + high ) // 2
293293 val , _ = self ._compute_cumulative_distribution (
294- np . full (( 1 ), arm ) ,
294+ arm ,
295295 np .full ((1 ), locations [mid ]),
296296 confoundings ,
297297 treatment_arms ,
@@ -339,11 +339,11 @@ def fit(
339339
340340 return self
341341
342- def predict (self , treatment_arms : np . ndarray , locations : np .ndarray ) -> np .ndarray :
342+ def predict (self , treatment_arm : int , locations : np .ndarray ) -> np .ndarray :
343343 """Compute cumulative distribution values.
344344
345345 Args:
346- treatment_arms (np.ndarray ): The index of the treatment arm.
346+ treatment_arm (int ): The index of the treatment arm.
347347 outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
348348
349349 Returns:
@@ -354,15 +354,13 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
354354 "This estimator has not been trained yet. Please call fit first"
355355 )
356356
357- unincluded_arms = set (treatment_arms ) - set (self .treatment_arms )
358-
359- if len (unincluded_arms ) > 0 :
357+ if treatment_arm not in self .treatment_arms :
360358 raise ValueError (
361- f"This treatment_arms argument contains arms not included in the training data: { unincluded_arms } "
359+ f"This target treatment arm was not included in the training data: { treatment_arm } "
362360 )
363361
364362 return self ._compute_cumulative_distribution (
365- treatment_arms ,
363+ treatment_arm ,
366364 locations ,
367365 self .confoundings ,
368366 self .treatment_arms ,
@@ -371,7 +369,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
371369
372370 def _compute_cumulative_distribution (
373371 self ,
374- target_treatment_arms : np . ndarray ,
372+ target_treatment_arm : int ,
375373 locations : np .ndarray ,
376374 confoundings : np .ndarray ,
377375 treatment_arms : np .ndarray ,
@@ -396,7 +394,7 @@ def __init__(self):
396394
397395 def _compute_cumulative_distribution (
398396 self ,
399- target_treatment_arms : np . ndarray ,
397+ target_treatment_arm : int ,
400398 locations : np .ndarray ,
401399 confoundings : np .ndarray ,
402400 treatment_arms : np .ndarray ,
@@ -405,7 +403,7 @@ def _compute_cumulative_distribution(
405403 """Compute the cumulative distribution values.
406404
407405 Args:
408- target_treatment_arms (np.ndarray ): The index of the treatment arm.
406+ target_treatment_arm (int ): The index of the treatment arm.
409407 locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
410408 confoundings: (np.ndarray): An array of confounding variables in the observed data.
411409 treatment_arms (np.ndarray): An array of treatment arms in the observed data.
@@ -426,22 +424,23 @@ def _compute_cumulative_distribution(
426424 d_confounding [arm ] = selected_confounding [sorted_indices ]
427425 d_outcome [arm ] = selected_outcome [sorted_indices ]
428426 cumulative_distribution = np .zeros (locations .shape )
429- for i , ( outcome , arm ) in enumerate (zip ( locations , target_treatment_arms ) ):
427+ for i , outcome in enumerate (locations ):
430428 cumulative_distribution [i ] = (
431- np .searchsorted (d_outcome [arm ], outcome , side = "right" )
432- ) / d_outcome [arm ]. shape [ 0 ]
429+ np .searchsorted (d_outcome [target_treatment_arm ], outcome , side = "right" )
430+ ) / len ( d_outcome [target_treatment_arm ])
433431 return cumulative_distribution , np .zeros ((n_obs , n_loc ))
434432
435433
436434class AdjustedDistributionEstimator (DistributionEstimatorBase ):
437435 """A class is for estimating the adjusted distribution function and computing the Distributional parameters based on the trained conditional estimator."""
438436
439- def __init__ (self , base_model , folds = 3 ):
437+ def __init__ (self , base_model , folds = 3 , is_multi_task = False ):
440438 """Initializes the AdjustedDistributionEstimator.
441439
442440 Args:
443441 base_model (scikit-learn estimator): The base model implementing used for conditional distribution function estimators. The model should implement fit(data, targets) and predict_proba(data).
444442 folds (int): The number of folds for cross-fitting.
443+ is_multi_task(bool): Whether to use multi-task learning. If True, your base model needs to support multi-task prediction (n_samples, n_features) -> (n_samples, n_targets).
445444
446445 Returns:
447446 AdjustedDistributionEstimator: An instance of the estimator.
@@ -454,11 +453,12 @@ def __init__(self, base_model, folds=3):
454453 )
455454 self .base_model = base_model
456455 self .folds = folds
456+ self .is_multi_task = is_multi_task
457457 super ().__init__ ()
458458
459459 def _compute_cumulative_distribution (
460460 self ,
461- target_treatment_arms : np . ndarray ,
461+ target_treatment_arm : int ,
462462 locations : np .ndarray ,
463463 confoundings : np .ndarray ,
464464 treatment_arms : np .ndarray ,
@@ -467,7 +467,7 @@ def _compute_cumulative_distribution(
467467 """Compute the cumulative distribution values.
468468
469469 Args:
470- target_treatment_arms (np.ndarray ): The index of the treatment arm.
470+ target_treatment_arm (int ): The index of the treatment arm.
471471 locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
472472 confoundings: (np.ndarray): An array of confounding variables in the observed data.
473473 treatment_arm (np.ndarray): An array of treatment arms in the observed data.
@@ -476,43 +476,75 @@ def _compute_cumulative_distribution(
476476 Returns:
477477 np.ndarray: Estimated cumulative distribution values.
478478 """
479- n_obs = outcomes .shape [0 ]
479+ n_records = outcomes .shape [0 ]
480480 n_loc = locations .shape [0 ]
481- cumulative_distribution = np .zeros (locations .shape )
482- superset_prediction = np .zeros ((n_obs , n_loc ))
483- for i , (location , arm ) in enumerate (zip (locations , target_treatment_arms )):
484- confounding_in_arm = confoundings [treatment_arms == arm ]
485- outcome_in_arm = outcomes [treatment_arms == arm ]
486- subset_prediction = np .zeros (outcome_in_arm .shape [0 ])
487- binominal = (outcome_in_arm <= location ) * 1
488- cdf = binominal .mean ()
481+ cumulative_distribution = np .zeros (n_loc )
482+ superset_prediction = np .zeros ((n_records , n_loc ))
483+ treatment_mask = treatment_arms == target_treatment_arm
484+ if self .is_multi_task :
485+ confounding_in_arm = confoundings [treatment_mask ]
486+ n_records_in_arm = len (confounding_in_arm )
487+ outcome_in_arm = outcomes [treatment_mask ] # (n_records)
488+ subset_prediction = np .zeros (
489+ (n_records_in_arm , n_loc )
490+ ) # (n_records_in_arm, n_loc)
491+ binominal = (outcomes .reshape (- 1 , 1 ) <= locations ) * 1 # (n_records, n_loc)
492+ cdf = binominal [treatment_mask ].mean (axis = 0 ) # (n_loc)
489493 for fold in range (self .folds ):
490- subset_mask = (
491- np . arange ( confounding_in_arm . shape [ 0 ]) % self . folds == fold
492- )
493- confounding_train = confounding_in_arm [~ subset_mask ]
494- confounding_fit = confounding_in_arm [subset_mask ]
494+ superset_mask = np . arange ( n_records ) % self . folds == fold
495+ subset_mask = superset_mask & treatment_mask
496+ subset_mask_inner = superset_mask [ treatment_mask ]
497+ confounding_train = confoundings [~ subset_mask ]
498+ confounding_fit = confoundings [subset_mask ]
495499 binominal_train = binominal [~ subset_mask ]
496- superset_mask = np .arange (self .outcomes .shape [0 ]) % self .folds == fold
497- if np .unique (binominal_train ).shape [0 ] == 1 :
498- subset_prediction [subset_mask ] = binominal_train [0 ]
499- superset_prediction [superset_mask , i ] = binominal_train [0 ]
500- continue
501500 model = deepcopy (self .base_model )
502501 model .fit (confounding_train , binominal_train )
503- subset_prediction [subset_mask ] = self ._compute_model_prediction (
502+ subset_prediction [subset_mask_inner ] = self ._compute_model_prediction (
504503 model , confounding_fit
505504 )
506- superset_prediction [superset_mask , i ] = self ._compute_model_prediction (
505+ superset_prediction [superset_mask ] = self ._compute_model_prediction (
507506 model , confoundings [superset_mask ]
508507 )
509- cumulative_distribution [i ] = (
510- cdf - subset_prediction .mean () + superset_prediction [:, i ].mean ()
511- )
508+ cumulative_distribution = (
509+ cdf - subset_prediction .mean (axis = 0 ) + superset_prediction .mean (axis = 0 )
510+ ) # (n_loc)
511+ else :
512+ for i , location in enumerate (locations ):
513+ confounding_in_arm = confoundings [treatment_mask ]
514+ outcome_in_arm = outcomes [treatment_mask ]
515+ subset_prediction = np .zeros (outcome_in_arm .shape [0 ])
516+ binominal = (outcomes <= location ) * 1 # (n_records)
517+ cdf = binominal [treatment_mask ].mean ()
518+ for fold in range (self .folds ):
519+ superset_mask = np .arange (n_records ) % self .folds == fold
520+ subset_mask = superset_mask & treatment_mask
521+ subset_mask_inner = superset_mask [treatment_mask ]
522+ confounding_train = confoundings [~ subset_mask ]
523+ confounding_fit = confoundings [subset_mask ]
524+ binominal_train = binominal [~ subset_mask ]
525+ if len (np .unique (binominal_train )) == 1 :
526+ subset_prediction [subset_mask_inner ] = binominal_train [0 ]
527+ superset_prediction [superset_mask , i ] = binominal_train [0 ]
528+ continue
529+ model = deepcopy (self .base_model )
530+ model .fit (confounding_train , binominal_train )
531+ subset_prediction [subset_mask_inner ] = (
532+ self ._compute_model_prediction (model , confounding_fit )
533+ )
534+ superset_prediction [superset_mask , i ] = (
535+ self ._compute_model_prediction (
536+ model , confoundings [superset_mask ]
537+ )
538+ )
539+ cumulative_distribution [i ] = (
540+ cdf - subset_prediction .mean () + superset_prediction [:, i ].mean ()
541+ )
512542 return cumulative_distribution , superset_prediction
513543
514544 def _compute_model_prediction (self , model , confoundings : np .ndarray ) -> np .ndarray :
515545 if hasattr (model , "predict_proba" ):
546+ if self .is_multi_task :
547+ return model .predict_proba (confoundings )
516548 return model .predict_proba (confoundings )[:, 1 ]
517549 else :
518550 return model .predict (confoundings )
0 commit comments