Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,6 @@ ENV/

# vim/vi generated
*.swp

# output zarr generated
*.zarr
20 changes: 18 additions & 2 deletions tests/models/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import pytest
import torch
import torchvision.models as torch_models
from torch import nn

from tiatoolbox import rcParam
from tiatoolbox import rcParam, utils
from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.models.models_abc import ModelABC, model_to
from tiatoolbox.utils import env_detection as toolbox_env

if TYPE_CHECKING:
Expand Down Expand Up @@ -149,3 +150,18 @@ def test_model_abc() -> None:
weights_path = fetch_pretrained_weights("alexnet-kather100k")
with pytest.raises(RuntimeError, match=r".*loading state_dict*"):
_ = model.load_weights_from_file(weights_path)


def test_model_to() -> None:
"""Test for placing model on device."""
# Test on GPU
# no GPU on GitHub Actions so this will crash
if not utils.env_detection.has_gpu():
model = torch_models.resnet18()
with pytest.raises((AssertionError, RuntimeError)):
_ = model_to(device="cuda", model=model)

# Test on CPU
model = torch_models.resnet18()
model = model_to(device="cpu", model=model)
assert isinstance(model, nn.Module)
4 changes: 2 additions & 2 deletions tests/models/test_arch_mapde.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None:
model = _load_mapde(name="mapde-conic")
patch = model.preproc(patch)
batch = torch.from_numpy(patch)[None]
model = model.to(select_device(on_gpu=ON_GPU))
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
model = model.to()
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
output = model.postproc(output[0])
assert np.all(output[0:2] == [[19, 171], [53, 89]])
2 changes: 1 addition & 1 deletion tests/models/test_arch_micronet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_functionality(
model = model.to(map_location)
pretrained = torch.load(weights_path, map_location=map_location)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
output = model.infer_batch(model, batch, device=map_location)
output, _ = model.postproc(output[0])
assert np.max(np.unique(output)) == 46

Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_arch_nuclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tiatoolbox.models import NuClick
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import imread
from tiatoolbox.utils.misc import select_device

ON_GPU = False

Expand Down Expand Up @@ -53,7 +54,7 @@ def test_functional_nuclick(
model = NuClick(num_input_channels=5, num_output_channels=1)
pretrained = torch.load(weights_path, map_location="cpu")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
postproc_masks = model.postproc(
output,
do_reconstruction=True,
Expand Down
17 changes: 13 additions & 4 deletions tests/models/test_arch_sccnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
import numpy as np
import torch

from tiatoolbox import utils
from tiatoolbox.models import SCCNN
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import env_detection
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader


def _load_sccnn(name: str) -> torch.nn.Module:
"""Loads SCCNN model with specified weights."""
model = SCCNN()
weights_path = fetch_pretrained_weights(name)
map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu())
map_location = select_device(on_gpu=env_detection.has_gpu())
pretrained = torch.load(weights_path, map_location=map_location)
model.load_state_dict(pretrained)

Expand All @@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None:
)
batch = torch.from_numpy(patch)[None]
model = _load_sccnn(name="sccnn-crchisto")
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(
model,
batch,
device=select_device(on_gpu=env_detection.has_gpu()),
)
output = model.postproc(output[0])
assert np.all(output == [[8, 7]])

model = _load_sccnn(name="sccnn-conic")
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(
model,
batch,
device=select_device(on_gpu=env_detection.has_gpu()),
)
output = model.postproc(output[0])
assert np.all(output == [[7, 8]])
5 changes: 3 additions & 2 deletions tests/models/test_arch_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.unet import UNetModel
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = False
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
_ = output[0]

# run untrained network to test for architecture
Expand All @@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None:
encoder_levels=[32, 64],
skip_type="concat",
)
_ = model.infer_batch(model, batch, on_gpu=ON_GPU)
_ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
11 changes: 6 additions & 5 deletions tests/models/test_arch_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import torch

from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
from tiatoolbox.utils.misc import model_to
from tiatoolbox.models.models_abc import model_to

ON_GPU = False
RNG = np.random.default_rng() # Numpy Random Generator
device = "cuda" if ON_GPU else "cpu"


def test_functional() -> None:
Expand Down Expand Up @@ -43,8 +44,8 @@ def test_functional() -> None:
try:
for backbone in backbones:
model = CNNModel(backbone, num_classes=1)
model_ = model_to(on_gpu=ON_GPU, model=model)
model.infer_batch(model_, samples, on_gpu=ON_GPU)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand All @@ -70,8 +71,8 @@ def test_timm_functional() -> None:
try:
for backbone in backbones:
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
model_ = model_to(on_gpu=ON_GPU, model=model)
model.infer_batch(model_, samples, on_gpu=ON_GPU)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
IOSegmentorConfig,
)
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
Expand All @@ -35,7 +36,7 @@ def test_engine(remote_sample: Callable, tmp_path: Path) -> None:
output_list = extractor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -82,7 +83,7 @@ def test_full_inference(
[mini_wsi_svs],
mode="wsi",
ioconfig=ioconfig,
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down
9 changes: 5 additions & 4 deletions tests/models/test_hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ResidualBlock,
TFSamepaddingLayer,
)
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader


Expand All @@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand All @@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand All @@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_original-consep")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand All @@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_original-kumar")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_hovernetplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tiatoolbox.models import HoVerNetPlus
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import imread
from tiatoolbox.utils.misc import select_device
from tiatoolbox.utils.transforms import imresize


Expand All @@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernetplus-oed")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
output = [v[0] for v in output]
output = model.postproc(output)
Expand Down
23 changes: 12 additions & 11 deletions tests/models/test_multi_task_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils import imwrite
from tiatoolbox.utils.metrics import f1_detection
from tiatoolbox.utils.misc import select_device

ON_GPU = toolbox_env.has_gpu()
BATCH_SIZE = 1 if not ON_GPU else 8 # 16
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand All @@ -83,7 +84,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -117,7 +118,7 @@ def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) ->
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -195,7 +196,7 @@ def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
masks=[sample_wsi_msk],
mode="wsi",
ioconfig=ioconfig,
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -230,7 +231,7 @@ def test_functionality_process_instance_predictions(
output = semantic_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -268,7 +269,7 @@ def test_empty_image(tmp_path: Path) -> None:
_ = multi_segmentor.predict(
[sample_patch_path],
mode="tile",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand All @@ -284,7 +285,7 @@ def test_empty_image(tmp_path: Path) -> None:
_ = multi_segmentor.predict(
[sample_patch_path],
mode="tile",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -312,7 +313,7 @@ def test_empty_image(tmp_path: Path) -> None:
_ = multi_segmentor.predict(
[sample_patch_path],
mode="tile",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
ioconfig=bcc_wsi_ioconfig,
Expand Down Expand Up @@ -361,7 +362,7 @@ def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
ioconfig=bcc_wsi_ioconfig,
Expand Down Expand Up @@ -413,7 +414,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
masks=[sample_wsi_msk],
mode="wsi",
ioconfig=ioconfig,
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Loading
Loading