|
| 1 | +""" |
| 2 | +Frustratingly Easy Domain Adaptation module. |
| 3 | +""" |
| 4 | + |
| 5 | +import warnings |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from sklearn.utils import check_array |
| 9 | +from sklearn.exceptions import NotFittedError |
| 10 | + |
| 11 | +from adapt.base import BaseAdaptEstimator, make_insert_doc |
| 12 | +from adapt.utils import check_arrays, check_estimator |
| 13 | + |
| 14 | + |
| 15 | +@make_insert_doc(supervised=True) |
| 16 | +class PRED(BaseAdaptEstimator): |
| 17 | + """ |
| 18 | + PRED: Feature Augmentation with SrcOnly Prediction |
| 19 | +
|
| 20 | + PRED uses the output of a source pretrain model as a feature in |
| 21 | + the target model. Specifically, PRED first trains a |
| 22 | + SrcOnly model. Then it runs the SrcOnly model on the target data. |
| 23 | + It uses the predictions made by the SrcOnly model as additional features |
| 24 | + and trains a second model on the target data, augmented with this new feature. |
| 25 | +
|
| 26 | + Parameters |
| 27 | + ---------- |
| 28 | + pretrain : bool (default=True) |
| 29 | + Weither to pretrain the estimator on |
| 30 | + source or not. If False, `estimator` |
| 31 | + should be already fitted. |
| 32 | +
|
| 33 | + Attributes |
| 34 | + ---------- |
| 35 | + estimator_src_ : object |
| 36 | + Fitted source estimator. |
| 37 | + |
| 38 | + estimator_ : object |
| 39 | + Fitted estimator. |
| 40 | + |
| 41 | + See also |
| 42 | + -------- |
| 43 | + FA |
| 44 | + |
| 45 | + Examples |
| 46 | + -------- |
| 47 | + >>> from sklearn.linear_model import RidgeClassifier |
| 48 | + >>> from adapt.utils import make_classification_da |
| 49 | + >>> from adapt.feature_based import PRED |
| 50 | + >>> Xs, ys, Xt, yt = make_classification_da() |
| 51 | + >>> model = PRED(RidgeClassifier(0.), Xt=Xt[[1, -1, -2]], yt=yt[[1, -1, -2]], |
| 52 | + ... pretrain=True, verbose=0, random_state=0) |
| 53 | + >>> model.fit(Xs, ys) |
| 54 | + >>> model.score(Xt, yt) |
| 55 | + 0.77 |
| 56 | +
|
| 57 | + References |
| 58 | + ---------- |
| 59 | + .. [1] `[1] <https://arxiv.org/pdf/0907.1815\ |
| 60 | +.pdf>`_ Daume III, H. "Frustratingly easy domain adaptation". In ACL, 2007. |
| 61 | + """ |
| 62 | + def __init__(self, |
| 63 | + estimator=None, |
| 64 | + Xt=None, |
| 65 | + yt=None, |
| 66 | + copy=True, |
| 67 | + pretrain=True, |
| 68 | + verbose=1, |
| 69 | + random_state=None, |
| 70 | + **params): |
| 71 | + |
| 72 | + names = self._get_param_names() |
| 73 | + kwargs = {k: v for k, v in locals().items() if k in names} |
| 74 | + kwargs.update(params) |
| 75 | + super().__init__(**kwargs) |
| 76 | + |
| 77 | + |
| 78 | + def fit_transform(self, Xs, Xt, ys, yt, domains=None, **kwargs): |
| 79 | + """ |
| 80 | + Fit embeddings. |
| 81 | + |
| 82 | + Parameters |
| 83 | + ---------- |
| 84 | + Xs : array |
| 85 | + Source input data. |
| 86 | + |
| 87 | + Xt : array |
| 88 | + Target input data. |
| 89 | + |
| 90 | + ys : array |
| 91 | + Source output data. |
| 92 | + |
| 93 | + yt : array |
| 94 | + Target output data. |
| 95 | + |
| 96 | + kwargs : key, value argument |
| 97 | + Not used, present here for adapt consistency. |
| 98 | + |
| 99 | + Returns |
| 100 | + ------- |
| 101 | + Xt_aug, yt : augmented input and output target data |
| 102 | + """ |
| 103 | + Xs, ys = check_arrays(Xs, ys) |
| 104 | + Xt, yt = check_arrays(Xt, yt) |
| 105 | + |
| 106 | + self.estimators_ = [] |
| 107 | + |
| 108 | + if self.pretrain: |
| 109 | + estimator = self.fit_estimator(Xs, ys, |
| 110 | + warm_start=False, |
| 111 | + random_state=self.random_state) |
| 112 | + self.estimator_src_ = estimator |
| 113 | + del self.estimator_ |
| 114 | + else: |
| 115 | + self.estimator_src_ = check_estimator(self.estimator, |
| 116 | + copy=self.copy, |
| 117 | + force_copy=True) |
| 118 | + |
| 119 | + yt_pred = self.estimator_src_.predict(Xt) |
| 120 | + |
| 121 | + if len(yt_pred.shape) < 2: |
| 122 | + yt_pred = yt_pred.reshape(-1, 1) |
| 123 | + |
| 124 | + X = np.concatenate((Xt, yt_pred), axis=-1) |
| 125 | + y = yt |
| 126 | + return X, y |
| 127 | + |
| 128 | + |
| 129 | + def transform(self, X, domain="tgt"): |
| 130 | + """ |
| 131 | + Return augmented features for X. |
| 132 | + |
| 133 | + If `domain="tgt"`, the prediction of the source model on `X` |
| 134 | + are added to `X`. |
| 135 | + |
| 136 | + If `domain="src"`, `X` is returned. |
| 137 | + |
| 138 | + Parameters |
| 139 | + ---------- |
| 140 | + X : array |
| 141 | + Input data. |
| 142 | +
|
| 143 | + domain : str (default="tgt") |
| 144 | + Choose between ``"source", "src"`` and |
| 145 | + ``"target", "tgt"`` feature augmentation. |
| 146 | +
|
| 147 | + Returns |
| 148 | + ------- |
| 149 | + X_emb : array |
| 150 | + Embeddings of X. |
| 151 | + """ |
| 152 | + X = check_array(X, allow_nd=True) |
| 153 | + |
| 154 | + if domain in ["tgt", "target"]: |
| 155 | + y_pred = self.estimator_src_.predict(X) |
| 156 | + if len(y_pred.shape) < 2: |
| 157 | + y_pred = y_pred.reshape(-1, 1) |
| 158 | + X_emb = np.concatenate((X, y_pred), axis=-1) |
| 159 | + elif domain in ["src", "source"]: |
| 160 | + X_emb = X |
| 161 | + else: |
| 162 | + raise ValueError("`domain `argument " |
| 163 | + "should be `tgt` or `src`, " |
| 164 | + "got, %s"%domain) |
| 165 | + return X_emb |
| 166 | + |
| 167 | + |
| 168 | + def predict(self, X, domain=None, **predict_params): |
| 169 | + """ |
| 170 | + Return estimator predictions after |
| 171 | + adaptation. |
| 172 | + |
| 173 | + If `domain="tgt"`, the input feature ``X`` are first transformed. |
| 174 | + Then the ``predict`` method of the fitted estimator |
| 175 | + ``estimator_`` is applied on the transformed ``X``. |
| 176 | + |
| 177 | + If `domain="src"`, ``estimator_src_`` is applied direclty |
| 178 | + on ``X``. |
| 179 | + |
| 180 | + Parameters |
| 181 | + ---------- |
| 182 | + X : array |
| 183 | + input data |
| 184 | + |
| 185 | + domain : str (default=None) |
| 186 | + For antisymetric feature-based method, |
| 187 | + different transformation of the input X |
| 188 | + are applied for different domains. The domain |
| 189 | + should then be specified between "src" and "tgt". |
| 190 | + If ``None`` the default transformation is the |
| 191 | + target one. |
| 192 | + |
| 193 | + Returns |
| 194 | + ------- |
| 195 | + y_pred : array |
| 196 | + prediction of the Adapt Model. |
| 197 | + """ |
| 198 | + X = check_array(X, ensure_2d=True, allow_nd=True, accept_sparse=True) |
| 199 | + if domain is None: |
| 200 | + domain = "tgt" |
| 201 | + X = self.transform(X, domain=domain) |
| 202 | + |
| 203 | + if domain in ["tgt", "target"]: |
| 204 | + return self.estimator_.predict(X, **predict_params) |
| 205 | + else: |
| 206 | + return self.estimator_src_.predict(X, **predict_params) |
| 207 | + |
| 208 | + |
| 209 | + def score(self, X, y, sample_weight=None, domain=None): |
| 210 | + """ |
| 211 | + Return the estimator score. |
| 212 | + |
| 213 | + If `domain="tgt"`, the input feature ``X`` are first transformed. |
| 214 | + Then the ``score`` method of the fitted estimator |
| 215 | + ``estimator_`` is applied on the transformed ``X``. |
| 216 | + |
| 217 | + If `domain="src"`, ``estimator_src_`` is applied direclty |
| 218 | + on ``X``. |
| 219 | + |
| 220 | + Parameters |
| 221 | + ---------- |
| 222 | + X : array |
| 223 | + input data |
| 224 | + |
| 225 | + y : array |
| 226 | + output data |
| 227 | + |
| 228 | + sample_weight : array (default=None) |
| 229 | + Sample weights |
| 230 | + |
| 231 | + domain : str (default=None) |
| 232 | + This parameter specifies for antisymetric |
| 233 | + feature-based method which transformation |
| 234 | + will be applied between "source" and "target". |
| 235 | + If ``None`` the transformation by default is |
| 236 | + the target one. |
| 237 | + |
| 238 | + Returns |
| 239 | + ------- |
| 240 | + score : float |
| 241 | + estimator score. |
| 242 | + """ |
| 243 | + X, y = check_arrays(X, y, accept_sparse=True) |
| 244 | + |
| 245 | + if domain is None: |
| 246 | + domain = "tgt" |
| 247 | + X = self.transform(X, domain=domain) |
| 248 | + |
| 249 | + if domain in ["tgt", "target"]: |
| 250 | + estimator = self.estimator_ |
| 251 | + else: |
| 252 | + estimator = self.estimator_src_ |
| 253 | + |
| 254 | + if hasattr(estimator, "score"): |
| 255 | + score = estimator.score(X, y, sample_weight) |
| 256 | + elif hasattr(estimator, "evaluate"): |
| 257 | + if np.prod(X.shape) <= 10**8: |
| 258 | + score = estimator.evaluate( |
| 259 | + X, y, |
| 260 | + sample_weight=sample_weight, |
| 261 | + batch_size=len(X) |
| 262 | + ) |
| 263 | + else: |
| 264 | + score = estimator.evaluate( |
| 265 | + X, y, |
| 266 | + sample_weight=sample_weight |
| 267 | + ) |
| 268 | + if isinstance(score, (tuple, list)): |
| 269 | + score = score[0] |
| 270 | + else: |
| 271 | + raise ValueError("Estimator does not implement" |
| 272 | + " score or evaluate method") |
| 273 | + return score |
0 commit comments