Skip to content

Commit 9f4c389

Browse files
committed
Restructure tests
1 parent b8c076f commit 9f4c389

File tree

1 file changed

+6
-18
lines changed

1 file changed

+6
-18
lines changed

tests/test_wrappers.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from pathlib import Path
2-
from tempfile import TemporaryDirectory
32

43
import pytest
5-
import torch
6-
from torchvision import transforms
4+
import torch as th
75

86
from CollaborativeCoding import MetricWrapper, load_data, load_model
97

108

119
def test_load_model():
12-
import torch as th
13-
1410
image_shape = (1, 16, 16)
1511
num_classes = 4
1612

@@ -36,9 +32,6 @@ def test_load_model():
3632

3733

3834
def test_load_data():
39-
from tempfile import TemporaryDirectory
40-
41-
import torch as th
4235
from torchvision import transforms
4336

4437
dataset_names = [
@@ -56,21 +49,16 @@ def test_load_data():
5649
]
5750
)
5851

59-
with TemporaryDirectory() as tmppath:
60-
for name in dataset_names:
61-
dataset = load_data(
62-
name, train=False, data_dir=Path(tmppath), transform=trans
63-
)
52+
for name in dataset_names:
53+
dataset = load_data(name, train=False, data_dir=Path.cwd(), transform=trans)
6454

65-
im, _ = dataset.__getitem__(0)
55+
im, _ = dataset.__getitem__(0)
6656

67-
assert dataset.__len__() != 0
68-
assert type(im) == th.Tensor and len(im.size()) == 3
57+
assert dataset.__len__() != 0
58+
assert type(im) is th.Tensor and len(im.size()) == 3
6959

7060

7161
def test_load_metric():
72-
import torch as th
73-
7462
metrics = ("entropy", "f1", "recall", "precision", "accuracy")
7563

7664
class_sizes = [3, 6, 10]

0 commit comments

Comments
 (0)