Skip to content

Commit ba0109f

Browse files
authored
🐛 Fix in test_arch_mapde and test_arch_sccnn (#911)
- If cuda is available model should be moved to cuda otherwise tests will fail as test data is moved to cuda.
1 parent 264b079 commit ba0109f

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

tests/models/test_arch_mapde.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
ON_GPU = toolbox_env.has_gpu()
1515

1616

17-
def _load_mapde(name: str) -> torch.nn.Module:
17+
def _load_mapde(name: str) -> MapDe:
1818
"""Loads MapDe model with specified weights."""
1919
model = MapDe()
2020
weights_path = fetch_pretrained_weights(name)
2121
map_location = select_device(on_gpu=ON_GPU)
2222
pretrained = torch.load(weights_path, map_location=map_location)
2323
model.load_state_dict(pretrained)
24-
24+
model.to(map_location)
2525
return model
2626

2727

@@ -45,7 +45,6 @@ def test_functionality(remote_sample: Callable) -> None:
4545
model = _load_mapde(name="mapde-conic")
4646
patch = model.preproc(patch)
4747
batch = torch.from_numpy(patch)[None]
48-
model = model.to()
4948
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
5049
output = model.postproc(output[0])
5150
assert np.all(output[0:2] == [[19, 171], [53, 89]])

tests/models/test_arch_sccnn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
from tiatoolbox.wsicore.wsireader import WSIReader
1313

1414

15-
def _load_sccnn(name: str) -> torch.nn.Module:
15+
def _load_sccnn(name: str) -> SCCNN:
1616
"""Loads SCCNN model with specified weights."""
1717
model = SCCNN()
1818
weights_path = fetch_pretrained_weights(name)
1919
map_location = select_device(on_gpu=env_detection.has_gpu())
2020
pretrained = torch.load(weights_path, map_location=map_location)
2121
model.load_state_dict(pretrained)
22-
22+
model.to(map_location)
2323
return model
2424

2525

@@ -48,7 +48,6 @@ def test_functionality(remote_sample: Callable) -> None:
4848
)
4949
output = model.postproc(output[0])
5050
assert np.all(output == [[8, 7]])
51-
5251
model = _load_sccnn(name="sccnn-conic")
5352
output = model.infer_batch(
5453
model,

0 commit comments

Comments
 (0)