
Commit dee723f

Merge branch 'master' into master
2 parents 12c9247 + 2b274f5


319 files changed: +37447 additions, -1256 deletions


.gitattributes

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+*.ipynb linguist-documentation

.github/workflows/check-docs.yml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
       run: |
         sudo apt install pandoc
         python -m pip install --upgrade pip
-        pip install jinja2==3.0.3 sphinx numpydoc nbsphinx sphinx_gallery sphinx_rtd_theme ipython
+        pip install jinja2==3.0.3 sphinx==4.4.0 numpydoc==1.2 nbsphinx==0.8.8 sphinx_gallery==0.10.1 sphinx_rtd_theme==1.0.0 ipython==8.0.1
     - name: Install adapt dependencies
       run: |
         python -m pip install --upgrade pip

.github/workflows/publish-doc-to-remote.yml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
       run: |
         sudo apt install pandoc
         python -m pip install --upgrade pip
-        pip install jinja2==3.0.3 sphinx numpydoc nbsphinx sphinx_gallery sphinx_rtd_theme ipython
+        pip install jinja2==3.0.3 sphinx==4.4.0 numpydoc==1.2 nbsphinx==0.8.8 sphinx_gallery==0.10.1 sphinx_rtd_theme==1.0.0 ipython==8.0.1
     - name: Install adapt dependencies
       run: |
         python -m pip install --upgrade pip

.github/workflows/run-test.yml

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ jobs:
         exclude:
           - os: windows-latest
             python-version: 3.9
+          - os: ubuntu-latest
+            python-version: 3.6
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v2

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -19,4 +19,6 @@ var/
 *.egg
 docs_build/
 docs/html/
-docs/doctrees/
+docs/doctrees/
+adapt/datasets.py
+datasets/

README.md

Lines changed: 201 additions & 49 deletions
Large diffs are not rendered by default.

adapt/base.py

Lines changed: 116 additions & 55 deletions
@@ -14,6 +14,11 @@
 from sklearn.exceptions import NotFittedError
 from tensorflow.keras import Model
 from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
+try:
+    from tensorflow.keras.optimizers.legacy import RMSprop
+except:
+    from tensorflow.keras.optimizers import RMSprop
+

 from adapt.utils import (check_estimator,
                          check_network,
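Note: the try/except import added above is a compatibility shim for the optimizer relocation in recent TensorFlow releases. A minimal standalone sketch of the same pattern (assuming only, as the hunk itself does, that newer versions expose the original implementations under tensorflow.keras.optimizers.legacy):

# Prefer the legacy RMSprop where it exists, fall back otherwise.
try:
    from tensorflow.keras.optimizers.legacy import RMSprop
except ImportError:
    from tensorflow.keras.optimizers import RMSprop

optimizer = RMSprop(learning_rate=0.001)  # same call site on either version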
@@ -282,8 +287,8 @@ def unsupervised_score(self, Xs, Xt):
         score : float
             Unsupervised score.
         """
-        Xs = check_array(np.array(Xs))
-        Xt = check_array(np.array(Xt))
+        Xs = check_array(Xs, accept_sparse=True)
+        Xt = check_array(Xt, accept_sparse=True)

         if hasattr(self, "transform"):
             args = [
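For context: sklearn's check_array rejects scipy sparse input unless accept_sparse is set, which is what this change enables. A quick illustration (not from the commit):

import numpy as np
from scipy import sparse
from sklearn.utils import check_array

Xs = sparse.csr_matrix(np.eye(3))
check_array(Xs, accept_sparse=True)  # returns the CSR matrix unchanged
# check_array(Xs)                    # raises TypeError: dense data is required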
@@ -306,13 +311,11 @@ def unsupervised_score(self, Xs, Xt):

             set_random_seed(self.random_state)
             bootstrap_index = np.random.choice(
-                len(Xs), size=len(Xs), replace=True, p=sample_weight)
+                Xs.shape[0], size=Xs.shape[0], replace=True, p=sample_weight)
             Xs = Xs[bootstrap_index]
         else:
             raise ValueError("The Adapt model should implement"
                              " a transform or predict_weights methods")
-        Xs = np.array(Xs)
-        Xt = np.array(Xt)
         return normalized_linear_discrepancy(Xs, Xt)


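The switch from len(Xs) to Xs.shape[0] matters for the same sparse inputs: scipy sparse matrices define shape but refuse len(). For example:

import numpy as np
from scipy import sparse

Xs = sparse.csr_matrix(np.eye(5))
print(Xs.shape[0])  # 5
# len(Xs)           # TypeError: sparse matrix length is ambiguous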
@@ -466,18 +469,27 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         """
         Xt, yt = self._get_target_data(Xt, yt)
         X, y = check_arrays(X, y)
+        self.n_features_in_ = X.shape[1]
         if yt is not None:
             Xt, yt = check_arrays(Xt, yt)
         else:
             Xt = check_array(Xt, ensure_2d=True, allow_nd=True)
         set_random_seed(self.random_state)
+
+        self.n_features_in_ = X.shape[1]

         if hasattr(self, "fit_weights"):
             if self.verbose:
                 print("Fit weights...")
-            self.weights_ = self.fit_weights(Xs=X, Xt=Xt,
-                                             ys=y, yt=yt,
-                                             domains=domains)
+            out = self.fit_weights(Xs=X, Xt=Xt,
+                                   ys=y, yt=yt,
+                                   domains=domains)
+            if isinstance(out, tuple):
+                self.weights_ = out[0]
+                X = out[1]
+                y = out[2]
+            else:
+                self.weights_ = out
             if "sample_weight" in fit_params:
                 fit_params["sample_weight"] *= self.weights_
             else:
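The isinstance(out, tuple) branch extends the fit_weights contract: an implementation may return the weights alone, or a (weights, X, y) tuple when it also resamples the inputs. A hedged sketch of a conforming method (the class is hypothetical, not part of the library):

import numpy as np

class ResamplingAdapter:
    def fit_weights(self, Xs, Xt, ys=None, yt=None, domains=None):
        weights = np.ones(Xs.shape[0])
        # Returning a tuple lets fit() pick up modified X and y as well.
        return weights, Xs, ys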
@@ -534,7 +546,7 @@ def fit_estimator(self, X, y, sample_weight=None,
         -------
         estimator_ : fitted estimator
         """
-        X, y = check_arrays(X, y)
+        X, y = check_arrays(X, y, accept_sparse=True)
         set_random_seed(random_state)

         if (not warm_start) or (not hasattr(self, "estimator_")):
@@ -613,7 +625,7 @@ def predict_estimator(self, X, **predict_params):
         y_pred : array
             prediction of estimator.
         """
-        X = check_array(X, ensure_2d=True, allow_nd=True)
+        X = check_array(X, ensure_2d=True, allow_nd=True, accept_sparse=True)
         predict_params = self._filter_params(self.estimator_.predict,
                                              predict_params)
         return self.estimator_.predict(X, **predict_params)
@@ -648,7 +660,7 @@ def predict(self, X, domain=None, **predict_params):
         y_pred : array
             prediction of the Adapt Model.
         """
-        X = check_array(X, ensure_2d=True, allow_nd=True)
+        X = check_array(X, ensure_2d=True, allow_nd=True, accept_sparse=True)
         if hasattr(self, "transform"):
             if domain is None:
                 domain = "tgt"
@@ -700,7 +712,7 @@ def score(self, X, y, sample_weight=None, domain=None):
         score : float
             estimator score.
         """
-        X, y = check_arrays(X, y)
+        X, y = check_arrays(X, y, accept_sparse=True)

         if domain is None:
             domain = "target"
@@ -788,7 +800,6 @@ def _get_legal_params(self, params):


     def __getstate__(self):
-        print("getting")
         dict_ = {k: v for k, v in self.__dict__.items()}
         if "estimator_" in dict_:
             if isinstance(dict_["estimator_"], Model):
@@ -810,7 +821,6 @@ def __getstate__(self):


     def __setstate__(self, dict_):
-        print("setting")
         if "estimator_" in dict_:
             if isinstance(dict_["estimator_"], dict):
                 dict_["estimator_"] = self._from_config_keras_model(
@@ -960,9 +970,10 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         epochs = fit_params.get("epochs", 1)
         batch_size = fit_params.pop("batch_size", 32)
         shuffle = fit_params.pop("shuffle", True)
+        buffer_size = fit_params.pop("buffer_size", None)
         validation_data = fit_params.pop("validation_data", None)
         validation_split = fit_params.pop("validation_split", 0.)
-        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+        validation_batch_size = fit_params.get("validation_batch_size", batch_size)

         # 2. Prepare datasets

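The new buffer_size fit parameter exposes the usual tf.data trade-off: a shuffle buffer as large as the dataset gives a uniform shuffle, a smaller one bounds memory at the cost of shuffle quality. For instance:

import tensorflow as tf

ds = tf.data.Dataset.range(1000)
full = ds.shuffle(buffer_size=1000)  # exact shuffle, buffers all elements
approx = ds.shuffle(buffer_size=32)  # approximate shuffle, bounded memory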
@@ -1000,8 +1011,7 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
                                for dom in range(self.n_sources_))
             )

-            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
-
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
         else:
             dataset_src = X

@@ -1031,47 +1041,62 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             self._initialize_networks()
             if isinstance(Xt, tf.data.Dataset):
                 first_elem = next(iter(Xt))
-                if (not isinstance(first_elem, tuple) or
-                    not len(first_elem)==2):
-                    raise ValueError("When first argument is a dataset. "
-                                     "It should return (x, y) tuples.")
+                if not isinstance(first_elem, tuple):
+                    shape = first_elem.shape
                 else:
                     shape = first_elem[0].shape
+                if self._check_for_batch(Xt):
+                    shape = shape[1:]
             else:
                 shape = Xt.shape[1:]
             self._initialize_weights(shape)

-        # validation_data = self._check_validation_data(validation_data,
-        #                                                validation_batch_size,
-        #                                                shuffle)
+
+        # 3.5 Get datasets length
+        self.length_src_ = self._get_length_dataset(dataset_src, domain="src")
+        self.length_tgt_ = self._get_length_dataset(dataset_tgt, domain="tgt")
+

         # 4. Prepare validation dataset
         if validation_data is None and validation_split>0.:
             if shuffle:
-                dataset_src = dataset_src.shuffle(buffer_size=1024)
-            frac = int(len(dataset_src)*validation_split)
+                dataset_src = dataset_src.shuffle(buffer_size=self.length_src_,
+                                                  reshuffle_each_iteration=False)
+            frac = int(self.length_src_*validation_split)
             validation_data = dataset_src.take(frac)
             dataset_src = dataset_src.skip(frac)
-            validation_data = validation_data.batch(batch_size)
+            if not self._check_for_batch(validation_data):
+                validation_data = validation_data.batch(validation_batch_size)
+
+        if validation_data is not None:
+            if isinstance(validation_data, tf.data.Dataset):
+                if not self._check_for_batch(validation_data):
+                    validation_data = validation_data.batch(validation_batch_size)

+
         # 5. Set datasets
         # Same length for src and tgt + complete last batch + shuffle
-        try:
-            max_size = max(len(dataset_src), len(dataset_tgt))
-            max_size = np.ceil(max_size / batch_size) * batch_size
-            repeat_src = np.ceil(max_size/len(dataset_src))
-            repeat_tgt = np.ceil(max_size/len(dataset_tgt))
-
-            dataset_src = dataset_src.repeat(repeat_src)
-            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
-
-            self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
-        except:
-            pass
-
         if shuffle:
-            dataset_src = dataset_src.shuffle(buffer_size=1024)
-            dataset_tgt = dataset_tgt.shuffle(buffer_size=1024)
+            if buffer_size is None:
+                dataset_src = dataset_src.shuffle(buffer_size=self.length_src_,
+                                                  reshuffle_each_iteration=True)
+                dataset_tgt = dataset_tgt.shuffle(buffer_size=self.length_tgt_,
+                                                  reshuffle_each_iteration=True)
+            else:
+                dataset_src = dataset_src.shuffle(buffer_size=buffer_size,
+                                                  reshuffle_each_iteration=True)
+                dataset_tgt = dataset_tgt.shuffle(buffer_size=buffer_size,
+                                                  reshuffle_each_iteration=True)
+
+        max_size = max(self.length_src_, self.length_tgt_)
+        max_size = np.ceil(max_size / batch_size) * batch_size
+        repeat_src = np.ceil(max_size/self.length_src_)
+        repeat_tgt = np.ceil(max_size/self.length_tgt_)
+
+        dataset_src = dataset_src.repeat(repeat_src).take(max_size)
+        dataset_tgt = dataset_tgt.repeat(repeat_tgt).take(max_size)
+
+        self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)

         # 5. Pretraining
         if not hasattr(self, "pretrain_"):
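The repeat/take logic above aligns source and target pipelines of unequal length so both yield the same number of elements per epoch. A standalone sketch of the same arithmetic on toy datasets:

import numpy as np
import tensorflow as tf

src = tf.data.Dataset.range(10)  # length 10
tgt = tf.data.Dataset.range(3)   # length 3
batch_size = 4

# Round the larger length up to a whole number of batches...
max_size = int(np.ceil(max(len(src), len(tgt)) / batch_size) * batch_size)  # 12
# ...then repeat each dataset enough times and truncate to exactly max_size.
src = src.repeat(int(np.ceil(max_size / 10))).take(max_size)  # 12 elements
tgt = tgt.repeat(int(np.ceil(max_size / 3))).take(max_size)   # 12 elements

pairs = tf.data.Dataset.zip((src, tgt)).batch(batch_size)     # 3 full batches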
@@ -1099,14 +1124,14 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             pre_verbose = prefit_params.pop("verbose", verbose)
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
-            pre_shuffle = prefit_params.pop("shuffle", shuffle)
             prefit_params.pop("validation_data", None)
-            prefit_params.pop("validation_split", None)
-            prefit_params.pop("validation_batch_size", None)

             # !!! shuffle is already done
-            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt))
+
+            if not self._check_for_batch(dataset):
+                dataset = dataset.batch(pre_batch_size)
+
             hist = super().fit(dataset, validation_data=validation_data,
                                epochs=pre_epochs, verbose=pre_verbose, **prefit_params)

@@ -1123,7 +1148,10 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             self.history_ = {}

         # .7 Training
-        dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+        dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt))
+
+        if not self._check_for_batch(dataset):
+            dataset = dataset.batch(batch_size)

         self.pretrain_ = False

@@ -1259,7 +1287,8 @@ def compile(self,
                 if "_" in name:
                     new_name = ""
                     for split in name.split("_"):
-                        new_name += split[0]
+                        if len(split) > 0:
+                            new_name += split[0]
                     name = new_name
                 else:
                     name = name[:3]
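The len(split) > 0 guard avoids an IndexError when a name contains consecutive or trailing underscores, since str.split("_") then yields empty strings:

name = "my__metric_"
new_name = ""
for split in name.split("_"):
    if len(split) > 0:
        new_name += split[0]
print(new_name)  # "mm" -- without the guard, ""[0] raises IndexError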
@@ -1284,7 +1313,7 @@ def compile(self,

         if ((not "optimizer" in compile_params) or
             (compile_params["optimizer"] is None)):
-            compile_params["optimizer"] = "rmsprop"
+            compile_params["optimizer"] = RMSprop()
         else:
             if optimizer is None:
                 if not isinstance(compile_params["optimizer"], str):
@@ -1331,7 +1360,8 @@ def train_step(self, data):
             loss = tf.reduce_mean(loss)

         # Run backwards pass.
-        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+        gradients = tape.gradient(loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
         self.compiled_metrics.update_state(ys, y_pred)
         # Collect metrics to return
         return_metrics = {}
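The replacement spells out the canonical two-step backward pass instead of optimizer.minimize. A minimal self-contained version of the same pattern (toy model and shapes assumed):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.RMSprop()
x = tf.random.normal((8, 3))
y = tf.random.normal((8, 1))

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))

gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))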
@@ -1573,6 +1603,37 @@ def _initialize_weights(self, shape_X):
         X_enc = self.encoder_(np.zeros((1,) + shape_X))
         if hasattr(self, "discriminator_"):
             self.discriminator_(X_enc)
+
+
+    def _get_length_dataset(self, dataset, domain="src"):
+        try:
+            length = len(dataset)
+        except:
+            if self.verbose:
+                print("Computing %s dataset size..."%domain)
+            if not hasattr(self, "length_%s_"%domain):
+                length = 0
+                for _ in dataset:
+                    length += 1
+            else:
+                length = getattr(self, "length_%s_"%domain)
+            if self.verbose:
+                print("Done!")
+        return length
+
+
+    def _check_for_batch(self, dataset):
+        if dataset.__class__.__name__ == "BatchDataset":
+            return True
+        if hasattr(dataset, "_input_dataset"):
+            return self._check_for_batch(dataset._input_dataset)
+        elif hasattr(dataset, "_datasets"):
+            checks = []
+            for data in dataset._datasets:
+                checks.append(self._check_for_batch(data))
+            return np.all(checks)
+        else:
+            return False


     def _unpack_data(self, data):
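_check_for_batch walks private tf.data attributes (_input_dataset, _datasets) to find a BatchDataset anywhere in the pipeline. A quick check of that behavior, assuming those private attribute names stay stable across TF versions (which is all the helper itself assumes):

import tensorflow as tf

ds = tf.data.Dataset.range(8)
batched = ds.batch(4)
wrapped = batched.prefetch(1)  # PrefetchDataset wrapping a BatchDataset

print(batched.__class__.__name__)                 # BatchDataset
print(wrapped._input_dataset.__class__.__name__)  # BatchDataset (found by recursion)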
@@ -1596,23 +1657,23 @@ def _get_disc_metrics(self, ys_disc, yt_disc):

     def _initialize_networks(self):
         if self.encoder is None:
-            self.encoder_ = get_default_encoder(name="encoder")
+            self.encoder_ = get_default_encoder(name="encoder", state=self.random_state)
         else:
             self.encoder_ = check_network(self.encoder,
                                           copy=self.copy,
                                           name="encoder")
         if self.task is None:
-            self.task_ = get_default_task(name="task")
+            self.task_ = get_default_task(name="task", state=self.random_state)
         else:
             self.task_ = check_network(self.task,
                                        copy=self.copy,
                                        name="task")
         if self.discriminator is None:
-            self.discriminator_ = get_default_discriminator(name="discriminator")
+            self.discriminator_ = get_default_discriminator(name="discriminator", state=self.random_state)
         else:
             self.discriminator_ = check_network(self.discriminator,
                                                 copy=self.copy,
                                                 name="discriminator")

     def _initialize_pretain_networks(self):
-        pass
+        pass
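The new state argument threaded into get_default_encoder/task/discriminator suggests seeding the default networks from random_state. A hedged sketch of what such a helper might look like (make_default_encoder is hypothetical, not the library's actual implementation):

import tensorflow as tf

def make_default_encoder(state=None, name="encoder"):
    # Seeding the initializer makes the default network reproducible.
    init = tf.keras.initializers.GlorotUniform(seed=state)
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(10, activation="relu", kernel_initializer=init)],
        name=name)

encoder = make_default_encoder(state=0)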
