Commit 819e138

🔀 Undo unwanted changes during merge.
1 parent 0d68ad1 commit 819e138

File tree (4 files changed: +20 additions, -27 deletions)

    tests/models/test_models_abc.py
    tests/test_utils.py
    tiatoolbox/cli/common.py
    tiatoolbox/cli/patch_predictor.py

tests/models/test_models_abc.py

Lines changed: 5 additions & 8 deletions
@@ -6,15 +6,15 @@

 import pytest
 import torch
+import torchvision.models as torch_models
 from torch import nn

-import tiatoolbox.models
 from tiatoolbox import rcParam, utils
 from tiatoolbox.models.architecture import (
     fetch_pretrained_weights,
     get_pretrained_model,
 )
-from tiatoolbox.models.models_abc import ModelABC
+from tiatoolbox.models.models_abc import ModelABC, model_to
 from tiatoolbox.utils import env_detection as toolbox_env

 if TYPE_CHECKING:

@@ -154,19 +154,16 @@ def test_model_abc() -> None:

 def test_model_to() -> None:
     """Test for placing model on device."""
-    import torchvision.models as torch_models
-    from torch import nn
-
     # Test on GPU
-    # no GPU on Travis so this will crash
+    # no GPU on GitHub Actions so this will crash
     if not utils.env_detection.has_gpu():
         model = torch_models.resnet18()
         with pytest.raises((AssertionError, RuntimeError)):
-            _ = tiatoolbox.models.models_abc.model_to(device="cuda", model=model)
+            _ = model_to(device="cuda", model=model)

     # Test on CPU
     model = torch_models.resnet18()
-    model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model)
+    model = model_to(device="cpu", model=model)
     assert isinstance(model, nn.Module)

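For reference, the call pattern exercised by the updated test can be reproduced on its own. The sketch below assumes only what the diff shows: model_to imported from tiatoolbox.models.models_abc, a torchvision ResNet-18 as the model being moved, and torch/torchvision installed.

import torchvision.models as torch_models
from torch import nn

from tiatoolbox.models.models_abc import model_to

# Build a small torchvision model and place it on the CPU via the helper
# used in the test above; the helper hands back an nn.Module.
model = torch_models.resnet18()
model = model_to(device="cpu", model=model)
assert isinstance(model, nn.Module)
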
tests/test_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1669,6 +1669,17 @@ def test_patch_pred_store() -> None:
     with pytest.raises(ValueError, match="coordinates"):
         misc.dict_to_store(patch_output, (1.0, 1.0))

+    patch_output = {
+        "predictions": [1, 0, 1],
+        "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
+        "other": "other",
+    }
+
+    store = misc.dict_to_store(patch_output, (1.0, 1.0))
+
+    # Check that it is an SQLiteStore containing the expected annotations
+    assert isinstance(store, SQLiteStore)
+

 def test_patch_pred_store_cdict() -> None:
     """Test patch_pred_store with a class dict."""

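The added lines boil down to the pattern sketched here: a dict of patch predictions with matching coordinates is converted into an annotation store at unit scale. The misc.dict_to_store call and its arguments are taken directly from the diff; the import locations for misc and SQLiteStore are assumptions about how the test module obtains those names.

from tiatoolbox.utils import misc  # assumed import path for the misc helpers
from tiatoolbox.annotation.storage import SQLiteStore  # assumed import path

# Patch-level output: one prediction per bounding box, plus an extra key,
# mirroring the fixture added in the test above.
patch_output = {
    "predictions": [1, 0, 1],
    "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
    "other": "other",
}

# Convert to an annotation store using a (1.0, 1.0) scale factor and check
# that the result is an SQLiteStore, as the new assertion does.
store = misc.dict_to_store(patch_output, (1.0, 1.0))
assert isinstance(store, SQLiteStore)
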
tiatoolbox/cli/common.py

Lines changed: 4 additions & 16 deletions
@@ -94,8 +94,10 @@ def cli_output_type(
     input_type: click.Choice | None = None,
 ) -> Callable:
     """Enables --file-types option for cli."""
-    if input_type is None:
-        input_type = click.Choice(["zarr", "AnnotationStore"], case_sensitive=False)
+    click_choices = click.Choice(
+        choices=["zarr", "AnnotationStore"], case_sensitive=False
+    )
+    input_type = click_choices if input_type is None else input_type
     return click.option(
         "--output-type",
         help=add_default_to_usage_help(usage_help, default),

@@ -410,20 +412,6 @@ def cli_yaml_config_path(
     )


-def cli_on_gpu(
-    usage_help: str = "Run the model on GPU.",
-    *,
-    default: bool = False,
-) -> Callable:
-    """Enables --on-gpu option for cli."""
-    return click.option(
-        "--on-gpu",
-        type=bool,
-        default=default,
-        help=add_default_to_usage_help(usage_help, default),
-    )
-
-
 def cli_num_loader_workers(
     usage_help: str = "Number of workers to load the data. Please note that they will "
     "also perform preprocessing.",

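The refactor above keeps the behaviour of cli_output_type the same: the default click.Choice is built up front and used only when the caller does not pass input_type. Below is a minimal standalone sketch of that default-Choice pattern in plain click; make_output_type_option is a hypothetical helper, and its help text and default are illustrative rather than the tiatoolbox values.

from __future__ import annotations

import click


def make_output_type_option(input_type: click.Choice | None = None):
    """Return a click option decorator with a default set of choices."""
    click_choices = click.Choice(
        choices=["zarr", "AnnotationStore"], case_sensitive=False
    )
    # Fall back to the default choices only when no input_type is supplied.
    input_type = click_choices if input_type is None else input_type
    return click.option(
        "--output-type",
        type=input_type,
        default="AnnotationStore",
        help="Format of the output (illustrative help text).",
    )


@click.command()
@make_output_type_option()
def run(output_type: str) -> None:
    """Echo the selected output type."""
    click.echo(f"output-type: {output_type}")
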
tiatoolbox/cli/patch_predictor.py

Lines changed: 0 additions & 3 deletions
@@ -2,8 +2,6 @@

 from __future__ import annotations

-import click
-
 from tiatoolbox.cli.common import (
     cli_batch_size,
     cli_device,

@@ -45,7 +43,6 @@
 @cli_num_loader_workers(default=0)
 @cli_output_type(
     default="AnnotationStore",
-    input_type=click.Choice(["zarr", "AnnotationStore"], case_sensitive=False),
 )
 @cli_patch_mode(default=False)
 @cli_return_probabilities(default=True)
