|
4 | 4 | from adapt.utils import make_classification_da |
5 | 5 | from adapt.feature_based import CCSA |
6 | 6 | from tensorflow.keras.initializers import GlorotUniform |
| 7 | +try: |
| 8 | + from tensorflow.keras.optimizers.legacy import Adam |
| 9 | +except: |
| 10 | + from tensorflow.keras.optimizers import Adam |
7 | 11 |
|
8 | 12 | np.random.seed(0) |
9 | 13 | tf.random.set_seed(0) |
|
18 | 22 |
|
19 | 23 | def test_ccsa(): |
20 | 24 | ccsa = CCSA(task=task, loss="categorical_crossentropy", |
21 | | - optimizer="adam", metrics=["acc"], gamma=0.1, random_state=0) |
| 25 | + optimizer=Adam(), metrics=["acc"], gamma=0.1, random_state=0) |
22 | 26 | ccsa.fit(Xs, tf.one_hot(ys, 2).numpy(), Xt=Xt[ind], |
23 | 27 | yt=tf.one_hot(yt, 2).numpy()[ind], epochs=100, verbose=0) |
24 | 28 | assert np.mean(ccsa.predict(Xt).argmax(1) == yt) > 0.8 |
25 | 29 |
|
26 | 30 | ccsa = CCSA(task=task, loss="categorical_crossentropy", |
27 | | - optimizer="adam", metrics=["acc"], gamma=1., random_state=0) |
| 31 | + optimizer=Adam(), metrics=["acc"], gamma=1., random_state=0) |
28 | 32 | ccsa.fit(Xs, tf.one_hot(ys, 2).numpy(), Xt=Xt[ind], |
29 | 33 | yt=tf.one_hot(yt, 2).numpy()[ind], epochs=100, verbose=0) |
30 | 34 |
|
|
0 commit comments