Skip to content

Commit 3b2713c

Browse files
Merge pull request #67 from KevinMusgrave/dev
v0.0.72
2 parents fed806d + e690f79 commit 3b2713c

File tree

8 files changed

+73
-31
lines changed

8 files changed

+73
-31
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ examples/**/officehome_resized.tar.gz
1717
examples/**/officehome
1818
examples/**/saved_models
1919
examples/**/lightning_logs
20-
zzz_pytorch_adapt_dataset_test_folder
20+
zzz_pytorch_adapt_dataset_test_folder
21+
.env

build_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
./format_code.sh
2-
python -m unittest discover && \
2+
RUN_DATASET_TESTS=true python -m unittest discover && \
33
rm -rfv build/ && \
44
rm -rfv dist/ && \
55
rm -rfv src/pytorch_adapt.egg-info/ && \

src/pytorch_adapt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.71"
1+
__version__ = "0.0.72"

src/pytorch_adapt/datasets/getters.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def get_datasets(
2424
folder,
2525
download=False,
2626
return_target_with_labels=False,
27+
supervised=False,
2728
transform_getter=None,
2829
):
2930
def getter(domains, train, is_training):
@@ -47,14 +48,19 @@ def getter(domains, train, is_training):
4748
output["src_train"] = SourceDataset(getter(src_domains, True, False))
4849
output["src_val"] = SourceDataset(getter(src_domains, False, False))
4950
if target_domains:
50-
output["target_train"] = TargetDataset(getter(target_domains, True, False))
51-
output["target_val"] = TargetDataset(getter(target_domains, False, False))
51+
output["target_train"] = TargetDataset(
52+
getter(target_domains, True, False), supervised=supervised
53+
)
54+
output["target_val"] = TargetDataset(
55+
getter(target_domains, False, False), supervised=supervised
56+
)
57+
# For academic setting: unsupervised learning w/ seperate target datasets that have gt lables for eval.
5258
if return_target_with_labels:
53-
output["target_train_with_labels"] = SourceDataset(
54-
getter(target_domains, True, False), domain=1
59+
output["target_train_with_labels"] = TargetDataset(
60+
getter(target_domains, True, False), domain=1, supervised=True
5561
)
56-
output["target_val_with_labels"] = SourceDataset(
57-
getter(target_domains, False, False), domain=1
62+
output["target_val_with_labels"] = TargetDataset(
63+
getter(target_domains, False, False), domain=1, supervised=True
5864
)
5965
if src_domains and target_domains:
6066
output["train"] = CombinedSourceAndTargetDataset(

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
TEST_FOLDER = "zzz_pytorch_adapt_test_folder"
3030
DATASET_FOLDER = "zzz_pytorch_adapt_dataset_test_folder"
3131
RUN_DATASET_TESTS = os.environ.get("RUN_DATASET_TESTS", False)
32+
RUN_DOMAINNET_DATASET_TESTS = os.environ.get("RUN_DOMAINNET_DATASET_TESTS", False)
3233

3334
TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]
3435
TEST_DEVICE = torch.device(device_from_environ)

tests/datasets/test_domainnet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44

55
from pytorch_adapt.datasets import DomainNet, DomainNet126, DomainNet126Full
66

7-
from .. import DATASET_FOLDER, RUN_DATASET_TESTS
7+
from .. import DATASET_FOLDER, RUN_DOMAINNET_DATASET_TESTS
88
from .utils import (
99
check_full,
1010
check_train_test_disjoint,
1111
check_train_test_matches_full,
1212
loop_through_dataset,
13-
skip_reason,
13+
skip_reason_domainnet,
1414
)
1515

1616

1717
class TestDomainNet(unittest.TestCase):
18-
@unittest.skipIf(not RUN_DATASET_TESTS, skip_reason)
18+
@unittest.skipIf(not RUN_DOMAINNET_DATASET_TESTS, skip_reason_domainnet)
1919
def test_domainnet(self):
2020
transform = torch_transforms.Compose(
2121
[
@@ -39,7 +39,7 @@ def test_domainnet(self):
3939
self.assertTrue(len(dataset) == length)
4040
loop_through_dataset(dataset)
4141

42-
@unittest.skipIf(not RUN_DATASET_TESTS, skip_reason)
42+
@unittest.skipIf(not RUN_DOMAINNET_DATASET_TESTS, skip_reason_domainnet)
4343
def test_domainnet126(self):
4444
check_train_test_matches_full(
4545
self,
@@ -50,7 +50,7 @@ def test_domainnet126(self):
5050
DATASET_FOLDER,
5151
)
5252

53-
@unittest.skipIf(not RUN_DATASET_TESTS, skip_reason)
53+
@unittest.skipIf(not RUN_DOMAINNET_DATASET_TESTS, skip_reason_domainnet)
5454
def test_domainnet126_full(self):
5555
check_full(
5656
self,

tests/datasets/test_getters.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,37 @@
1818

1919

2020
class TestGetters(unittest.TestCase):
21-
def helper(self, datasets, src_class, target_class, sizes):
21+
def helper(
22+
self,
23+
datasets,
24+
src_class,
25+
target_class,
26+
sizes,
27+
target_with_labels=False,
28+
supervised=False,
29+
):
2230
for k in ["src_train", "src_val"]:
2331
self.assertTrue(isinstance(datasets[k].dataset.datasets[0], src_class))
32+
self.assertTrue(isinstance(datasets[k], SourceDataset))
2433
self.assertTrue(len(datasets[k]) == sizes[k])
25-
for k in ["target_train", "target_val"]:
34+
35+
target_splits = ["target_train", "target_val"]
36+
if target_with_labels:
37+
target_splits += ["target_train_with_labels", "target_val_with_labels"]
38+
target_sizes = {k for k in sizes.keys() if k.startswith("target")}
39+
40+
self.assertTrue(set(target_splits) == target_sizes)
41+
for k in target_splits:
2642
self.assertTrue(isinstance(datasets[k].dataset.datasets[0], target_class))
43+
self.assertTrue(isinstance(datasets[k], TargetDataset))
2744
self.assertTrue(len(datasets[k]) == sizes[k])
45+
if supervised:
46+
# target_train and target_val will be supervised
47+
# and if target_with_labels is true, the with_labels will also be supervised
48+
self.assertTrue(datasets[k].supervised)
49+
else:
50+
# otherwise, only the ones with labels are supervised
51+
self.assertTrue(datasets[k].supervised == k.endswith("with_labels"))
2852

2953
@unittest.skipIf(not RUN_DATASET_TESTS, skip_reason)
3054
def test_empty_array(self):
@@ -41,21 +65,30 @@ def test_empty_array(self):
4165
self.assertTrue(isinstance(datasets["train"], TargetDataset))
4266

4367
@unittest.skipIf(not RUN_DATASET_TESTS, skip_reason)
44-
def test_get_mnist_mnistm(self):
45-
datasets = get_mnist_mnistm(
46-
["mnist"], ["mnistm"], folder=DATASET_FOLDER, download=True
47-
)
48-
self.helper(
49-
datasets,
50-
MNIST,
51-
MNISTM,
52-
{
53-
"src_train": 60000,
54-
"src_val": 10000,
55-
"target_train": 59001,
56-
"target_val": 9001,
57-
},
58-
)
68+
def test_get_mnist_mnistm_return_targets(self):
69+
for target_with_labels in [False, True]:
70+
for supervised in [False, True]:
71+
datasets = get_mnist_mnistm(
72+
["mnist"],
73+
["mnistm"],
74+
folder=DATASET_FOLDER,
75+
download=True,
76+
return_target_with_labels=target_with_labels,
77+
supervised=supervised,
78+
)
79+
len_dict = {
80+
"src_train": 60000,
81+
"src_val": 10000,
82+
"target_train": 59001,
83+
"target_val": 9001,
84+
}
85+
if target_with_labels:
86+
len_dict["target_train_with_labels"] = len_dict["target_train"]
87+
len_dict["target_val_with_labels"] = len_dict["target_val"]
88+
89+
self.helper(
90+
datasets, MNIST, MNISTM, len_dict, target_with_labels, supervised
91+
)
5992

6093
@unittest.skipIf(not RUN_DATASET_TESTS, skip_reason)
6194
def test_officehome(self):

tests/datasets/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torchvision import transforms as torch_transforms
77

88
skip_reason = "RUN_DATASET_TESTS is False"
9+
skip_reason_domainnet = "RUN_DOMAINNET_DATASET_TESTS is False"
910

1011

1112
def simple_transform():

0 commit comments

Comments
 (0)