
Commit cf31d38

Merge pull request #19 from CyberAgentAILab/feat/multi-task
Support multi-task learning
2 parents 1d8043a + 85fc069 · commit cf31d38

File tree

8 files changed: +235 -194 lines changed

docs/source/_static/qte.png

-27.4 KB

docs/source/get_started.rst

Lines changed: 13 additions & 4 deletions
@@ -59,13 +59,13 @@ Then, let's build an empirical cumulative distribution function (CDF).
 
     estimator = dte_adj.SimpleDistributionEstimator()
     estimator.fit(X, D, Y)
-    cdf = estimator.predict(D, Y)
+    locations = np.linspace(Y.min(), Y.max(), 20)
+    cdf = estimator.predict(1, locations)
 
 Distributional treatment effect (DTE) can be computed easily in the following code.
 
 .. code-block:: python
 
-    locations = np.linspace(Y.min(), Y.max(), 20)
     dte, lower_bound, upper_bound = estimator.predict_dte(target_treatment_arm=1, control_treatment_arm=0, locations=locations, variance_type="simple")
 
 A convenience function is available to visualize distribution effects. This method can be used for other distribution parameters including Probability Treatment Effect (PTE) and Quantile Treatment Effect (QTE).
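The convenience plotting helper itself lies outside this hunk; as a rough sketch, the same kind of plot can be drawn directly with matplotlib from the arrays returned by predict_dte above (the styling below is illustrative and is not the library's helper):

    import matplotlib.pyplot as plt

    # Illustrative only: plot the DTE point estimates with their confidence band,
    # using locations, dte, lower_bound, upper_bound from the snippet above.
    plt.plot(locations, dte, marker="o", label="DTE")
    plt.fill_between(locations, lower_bound, upper_bound, alpha=0.3, label="confidence band")
    plt.xlabel("outcome location")
    plt.ylabel("distributional treatment effect")
    plt.legend()
    plt.show()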
@@ -89,7 +89,7 @@ In the following example, we use Logistic Regression. Please make sure that your
     logit = LogisticRegression()
     estimator = dte_adj.AdjustedDistributionEstimator(logit, folds=3)
     estimator.fit(X, D, Y)
-    cdf = estimator.predict(D, Y)
+    cdf = estimator.predict(1, locations)
 
 DTE can be computed and visualized in the following code.
 
@@ -155,4 +155,13 @@ To compute QTE, we use "predict_qte" method. The confidence band is computed by
    :alt: QTE of adjusted estimator
    :height: 300px
    :width: 450px
-   :align: center
+   :align: center
+
+You can use any model with a "predict_proba" or "predict" method to adjust the distribution function estimation. For example, the following code uses an XGBoost classifier to estimate the conditional distribution.
+
+.. code-block:: python
+
+    import xgboost as xgb
+    estimator = dte_adj.AdjustedDistributionEstimator(xgb.XGBClassifier(), folds=3)
+    estimator.fit(X, D, Y)
+    cdf = estimator.predict(1, locations)

dte_adj/__init__.py

Lines changed: 81 additions & 49 deletions
@@ -139,8 +139,8 @@ def predict_qte(
 
         qte_var = qtes.var(axis=0)
 
-        qte_lower = qte + norm.ppf(alpha / 2) / np.sqrt(qte_var)
-        qte_upper = qte + norm.ppf(1 - alpha / 2) / np.sqrt(qte_var)
+        qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var)
+        qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var)
 
         return qte, qte_lower, qte_upper
 
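The corrected lines multiply by the standard error instead of dividing by it, which is the usual two-sided normal-approximation band. Writing the estimated QTE at a given quantile as \(\hat{\tau}_q\):

    \hat{\tau}_q \pm z_{1-\alpha/2}\,\sqrt{\widehat{\operatorname{Var}}(\hat{\tau}_q)}, \qquad z_{1-\alpha/2} = \Phi^{-1}(1 - \alpha/2)

Since norm.ppf(alpha / 2) equals \(-z_{1-\alpha/2}\), the first corrected line yields the lower bound and the second the upper bound.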
@@ -155,14 +155,14 @@ def _compute_dtes(
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Compute expected DTEs."""
         treatment_cdf, treatment_cdf_mat = self._compute_cumulative_distribution(
-            np.full(locations.shape, target_treatment_arm),
+            target_treatment_arm,
             locations,
             self.confoundings,
             self.treatment_arms,
             self.outcomes,
         )
         control_cdf, control_cdf_mat = self._compute_cumulative_distribution(
-            np.full(locations.shape, control_treatment_arm),
+            control_treatment_arm,
             locations,
             self.confoundings,
             self.treatment_arms,
@@ -207,7 +207,7 @@ def _compute_ptes(
         """Compute expected PTEs."""
         treatment_cumulative_pre, treatment_cdf_mat_pre = (
             self._compute_cumulative_distribution(
-                np.full(locations.shape, target_treatment_arm),
+                target_treatment_arm,
                 locations,
                 self.confoundings,
                 self.treatment_arms,
@@ -216,7 +216,7 @@
         )
         treatment_cumulative_post, treatment_cdf_mat_post = (
             self._compute_cumulative_distribution(
-                np.full(locations.shape, target_treatment_arm),
+                target_treatment_arm,
                 locations + width,
                 self.confoundings,
                 self.treatment_arms,
@@ -226,7 +226,7 @@
         treatment_pdf = treatment_cumulative_post - treatment_cumulative_pre
         control_cumulative_pre, control_cdf_mat_pre = (
             self._compute_cumulative_distribution(
-                np.full(locations.shape, control_treatment_arm),
+                control_treatment_arm,
                 locations,
                 self.confoundings,
                 self.treatment_arms,
@@ -235,7 +235,7 @@
         )
         control_cumulative_post, control_cdf_mat_post = (
             self._compute_cumulative_distribution(
-                np.full(locations.shape, control_treatment_arm),
+                control_treatment_arm,
                 locations + width,
                 self.confoundings,
                 self.treatment_arms,
@@ -291,7 +291,7 @@ def find_quantile(quantile, arm):
             while low <= high:
                 mid = (low + high) // 2
                 val, _ = self._compute_cumulative_distribution(
-                    np.full((1), arm),
+                    arm,
                     np.full((1), locations[mid]),
                     confoundings,
                     treatment_arms,
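The surrounding find_quantile helper (only partially visible in this hunk) locates a quantile by binary-searching the candidate locations for the smallest value whose CDF reaches the target. A standalone sketch of that idea, with an assumed feasibility rule that may differ from the library's exact loop:

    import numpy as np

    # Standalone illustration, not library code: invert an empirical CDF by
    # binary search over sorted candidate locations.
    rng = np.random.default_rng(0)
    outcomes = np.sort(rng.normal(size=1000))
    locations = outcomes  # candidate quantile locations

    def empirical_cdf(x):
        return np.searchsorted(outcomes, x, side="right") / len(outcomes)

    def find_quantile(quantile):
        low, high = 0, len(locations) - 1
        best = high
        while low <= high:
            mid = (low + high) // 2
            if empirical_cdf(locations[mid]) >= quantile:
                best = mid        # feasible; try a smaller location
                high = mid - 1
            else:
                low = mid + 1
        return locations[best]

    print(find_quantile(0.5), np.quantile(outcomes, 0.5))  # both near the sample median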
@@ -339,11 +339,11 @@ def fit(
 
         return self
 
-    def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarray:
+    def predict(self, treatment_arm: int, locations: np.ndarray) -> np.ndarray:
         """Compute cumulative distribution values.
 
         Args:
-            treatment_arms (np.ndarray): The index of the treatment arm.
+            treatment_arm (int): The index of the treatment arm.
             outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
 
         Returns:
@@ -354,15 +354,13 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarr
                 "This estimator has not been trained yet. Please call fit first"
             )
 
-        unincluded_arms = set(treatment_arms) - set(self.treatment_arms)
-
-        if len(unincluded_arms) > 0:
+        if treatment_arm not in self.treatment_arms:
             raise ValueError(
-                f"This treatment_arms argument contains arms not included in the training data: {unincluded_arms}"
+                f"This target treatment arm was not included in the training data: {treatment_arm}"
             )
 
         return self._compute_cumulative_distribution(
-            treatment_arms,
+            treatment_arm,
             locations,
             self.confoundings,
             self.treatment_arms,
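With the new scalar signature, an arm index that never appeared during fit is rejected immediately. Assuming the estimator from get_started.rst has been fitted and that arm 5 is absent from D:

    # Illustrative only: unseen arms now raise ValueError from predict().
    try:
        estimator.predict(5, locations)
    except ValueError as err:
        print(err)  # "This target treatment arm was not included in the training data: 5"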
@@ -371,7 +369,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarr
 
     def _compute_cumulative_distribution(
         self,
-        target_treatment_arms: np.ndarray,
+        target_treatment_arm: int,
         locations: np.ndarray,
         confoundings: np.ndarray,
         treatment_arms: np.ndarray,
@@ -396,7 +394,7 @@ def __init__(self):
 
     def _compute_cumulative_distribution(
         self,
-        target_treatment_arms: np.ndarray,
+        target_treatment_arm: int,
         locations: np.ndarray,
         confoundings: np.ndarray,
         treatment_arms: np.ndarray,
@@ -405,7 +403,7 @@ def _compute_cumulative_distribution(
         """Compute the cumulative distribution values.
 
         Args:
-            target_treatment_arms (np.ndarray): The index of the treatment arm.
+            target_treatment_arm (int): The index of the treatment arm.
             locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
             confoundings: (np.ndarray): An array of confounding variables in the observed data.
             treatment_arms (np.ndarray): An array of treatment arms in the observed data.
@@ -426,22 +424,23 @@ def _compute_cumulative_distribution(
             d_confounding[arm] = selected_confounding[sorted_indices]
             d_outcome[arm] = selected_outcome[sorted_indices]
         cumulative_distribution = np.zeros(locations.shape)
-        for i, (outcome, arm) in enumerate(zip(locations, target_treatment_arms)):
+        for i, outcome in enumerate(locations):
             cumulative_distribution[i] = (
-                np.searchsorted(d_outcome[arm], outcome, side="right")
-            ) / d_outcome[arm].shape[0]
+                np.searchsorted(d_outcome[target_treatment_arm], outcome, side="right")
+            ) / len(d_outcome[target_treatment_arm])
         return cumulative_distribution, np.zeros((n_obs, n_loc))
 
 
 class AdjustedDistributionEstimator(DistributionEstimatorBase):
     """A class is for estimating the adjusted distribution function and computing the Distributional parameters based on the trained conditional estimator."""
 
-    def __init__(self, base_model, folds=3):
+    def __init__(self, base_model, folds=3, is_multi_task=False):
         """Initializes the AdjustedDistributionEstimator.
 
         Args:
             base_model (scikit-learn estimator): The base model implementing used for conditional distribution function estimators. The model should implement fit(data, targets) and predict_proba(data).
             folds (int): The number of folds for cross-fitting.
+            is_multi_task (bool): Whether to use multi-task learning. If True, your base model needs to support multi-task prediction (n_samples, n_features) -> (n_samples, n_targets).
 
         Returns:
             AdjustedDistributionEstimator: An instance of the estimator.
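A minimal usage sketch for the new flag, assuming a multi-output base model; MLPRegressor is only one example of a model whose predict maps (n_samples, n_features) to (n_samples, n_targets), which is the contract multi-task mode documents above:

    import numpy as np
    import dte_adj
    from sklearn.neural_network import MLPRegressor

    # Illustrative only: in multi-task mode one multi-output model is fit per fold,
    # covering all locations at once, instead of one model per (location, fold) pair.
    # A classifier whose predict_proba returns (n_samples, n_targets) would also fit the contract.
    base_model = MLPRegressor(hidden_layer_sizes=(32,), max_iter=500)
    estimator = dte_adj.AdjustedDistributionEstimator(base_model, folds=3, is_multi_task=True)
    estimator.fit(X, D, Y)  # X, D, Y as in docs/source/get_started.rst
    locations = np.linspace(Y.min(), Y.max(), 20)
    cdf = estimator.predict(1, locations)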
@@ -454,11 +453,12 @@ def __init__(self, base_model, folds=3):
             )
         self.base_model = base_model
         self.folds = folds
+        self.is_multi_task = is_multi_task
         super().__init__()
 
     def _compute_cumulative_distribution(
         self,
-        target_treatment_arms: np.ndarray,
+        target_treatment_arm: int,
         locations: np.ndarray,
         confoundings: np.ndarray,
         treatment_arms: np.ndarray,
@@ -467,7 +467,7 @@ def _compute_cumulative_distribution(
         """Compute the cumulative distribution values.
 
         Args:
-            target_treatment_arms (np.ndarray): The index of the treatment arm.
+            target_treatment_arm (int): The index of the treatment arm.
             locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
             confoundings: (np.ndarray): An array of confounding variables in the observed data.
             treatment_arm (np.ndarray): An array of treatment arms in the observed data.
@@ -476,43 +476,75 @@ def _compute_cumulative_distribution(
         Returns:
             np.ndarray: Estimated cumulative distribution values.
         """
-        n_obs = outcomes.shape[0]
+        n_records = outcomes.shape[0]
         n_loc = locations.shape[0]
-        cumulative_distribution = np.zeros(locations.shape)
-        superset_prediction = np.zeros((n_obs, n_loc))
-        for i, (location, arm) in enumerate(zip(locations, target_treatment_arms)):
-            confounding_in_arm = confoundings[treatment_arms == arm]
-            outcome_in_arm = outcomes[treatment_arms == arm]
-            subset_prediction = np.zeros(outcome_in_arm.shape[0])
-            binominal = (outcome_in_arm <= location) * 1
-            cdf = binominal.mean()
+        cumulative_distribution = np.zeros(n_loc)
+        superset_prediction = np.zeros((n_records, n_loc))
+        treatment_mask = treatment_arms == target_treatment_arm
+        if self.is_multi_task:
+            confounding_in_arm = confoundings[treatment_mask]
+            n_records_in_arm = len(confounding_in_arm)
+            outcome_in_arm = outcomes[treatment_mask]  # (n_records)
+            subset_prediction = np.zeros(
+                (n_records_in_arm, n_loc)
+            )  # (n_records_in_arm, n_loc)
+            binominal = (outcomes.reshape(-1, 1) <= locations) * 1  # (n_records, n_loc)
+            cdf = binominal[treatment_mask].mean(axis=0)  # (n_loc)
             for fold in range(self.folds):
-                subset_mask = (
-                    np.arange(confounding_in_arm.shape[0]) % self.folds == fold
-                )
-                confounding_train = confounding_in_arm[~subset_mask]
-                confounding_fit = confounding_in_arm[subset_mask]
+                superset_mask = np.arange(n_records) % self.folds == fold
+                subset_mask = superset_mask & treatment_mask
+                subset_mask_inner = superset_mask[treatment_mask]
+                confounding_train = confoundings[~subset_mask]
+                confounding_fit = confoundings[subset_mask]
                 binominal_train = binominal[~subset_mask]
-                superset_mask = np.arange(self.outcomes.shape[0]) % self.folds == fold
-                if np.unique(binominal_train).shape[0] == 1:
-                    subset_prediction[subset_mask] = binominal_train[0]
-                    superset_prediction[superset_mask, i] = binominal_train[0]
-                    continue
                 model = deepcopy(self.base_model)
                 model.fit(confounding_train, binominal_train)
-                subset_prediction[subset_mask] = self._compute_model_prediction(
+                subset_prediction[subset_mask_inner] = self._compute_model_prediction(
                     model, confounding_fit
                 )
-                superset_prediction[superset_mask, i] = self._compute_model_prediction(
+                superset_prediction[superset_mask] = self._compute_model_prediction(
                    model, confoundings[superset_mask]
                )
-            cumulative_distribution[i] = (
-                cdf - subset_prediction.mean() + superset_prediction[:, i].mean()
-            )
+            cumulative_distribution = (
+                cdf - subset_prediction.mean(axis=0) + superset_prediction.mean(axis=0)
+            )  # (n_loc)
+        else:
+            for i, location in enumerate(locations):
+                confounding_in_arm = confoundings[treatment_mask]
+                outcome_in_arm = outcomes[treatment_mask]
+                subset_prediction = np.zeros(outcome_in_arm.shape[0])
+                binominal = (outcomes <= location) * 1  # (n_records)
+                cdf = binominal[treatment_mask].mean()
+                for fold in range(self.folds):
+                    superset_mask = np.arange(n_records) % self.folds == fold
+                    subset_mask = superset_mask & treatment_mask
+                    subset_mask_inner = superset_mask[treatment_mask]
+                    confounding_train = confoundings[~subset_mask]
+                    confounding_fit = confoundings[subset_mask]
+                    binominal_train = binominal[~subset_mask]
+                    if len(np.unique(binominal_train)) == 1:
+                        subset_prediction[subset_mask_inner] = binominal_train[0]
+                        superset_prediction[superset_mask, i] = binominal_train[0]
+                        continue
+                    model = deepcopy(self.base_model)
+                    model.fit(confounding_train, binominal_train)
+                    subset_prediction[subset_mask_inner] = (
+                        self._compute_model_prediction(model, confounding_fit)
+                    )
+                    superset_prediction[superset_mask, i] = (
+                        self._compute_model_prediction(
+                            model, confoundings[superset_mask]
+                        )
+                    )
+                cumulative_distribution[i] = (
+                    cdf - subset_prediction.mean() + superset_prediction[:, i].mean()
+                )
         return cumulative_distribution, superset_prediction
 
     def _compute_model_prediction(self, model, confoundings: np.ndarray) -> np.ndarray:
         if hasattr(model, "predict_proba"):
+            if self.is_multi_task:
+                return model.predict_proba(confoundings)
             return model.predict_proba(confoundings)[:, 1]
         else:
             return model.predict(confoundings)
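Three masks drive the cross-fitting in both branches: superset_mask picks a fold over all records, subset_mask restricts that fold to the target arm, and subset_mask_inner expresses the same rows in arm-local indexing. A standalone numpy check of that relationship (not library code):

    import numpy as np

    # Standalone illustration of the masks used above.
    n_records, folds, target_arm, fold = 10, 3, 1, 0
    treatment_arms = np.array([0, 1, 1, 0, 1, 0, 1, 1, 0, 1])
    treatment_mask = treatment_arms == target_arm

    superset_mask = np.arange(n_records) % folds == fold   # fold members among all records
    subset_mask = superset_mask & treatment_mask            # fold members in the target arm
    subset_mask_inner = superset_mask[treatment_mask]       # same rows, arm-local indexing

    # Selecting fold-and-arm rows from all records equals selecting the fold rows
    # from the arm-only subset, which is why the two prediction arrays line up.
    idx = np.arange(n_records)
    assert np.array_equal(idx[subset_mask], idx[treatment_mask][subset_mask_inner])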

0 commit comments
