Skip to content

Commit 96676c8

Browse files
committed
✨ Add CLI interface
1 parent 5997671 commit 96676c8

File tree

7 files changed

+175
-130
lines changed

7 files changed

+175
-130
lines changed

tests/engines/test_semantic_segmentor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import numpy as np
1111
import torch
1212
import zarr
13+
from click.testing import CliRunner
1314

15+
from tiatoolbox import cli
1416
from tiatoolbox.annotation import SQLiteStore
1517
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
1618
from tiatoolbox.utils import env_detection as toolbox_env
@@ -354,3 +356,28 @@ def test_wsi_segmentor_annotationstore(
354356
zarr_group = zarr.open(output[sample_svs].with_suffix(".zarr"), mode="r")
355357
assert "probabilities" in zarr_group
356358
assert "Probability maps cannot be saved as AnnotationStore." in caplog.text
359+
360+
361+
# -------------------------------------------------------------------------------------
362+
# Command Line Interface
363+
# -------------------------------------------------------------------------------------
364+
365+
366+
def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None:
367+
"""Test for models CLI single file."""
368+
runner = CliRunner()
369+
models_wsi_result = runner.invoke(
370+
cli.main,
371+
[
372+
"semantic-segmentor",
373+
"--img-input",
374+
str(sample_svs),
375+
"--patch-mode",
376+
"False",
377+
"--output-path",
378+
str(tmp_path / "output"),
379+
],
380+
)
381+
382+
assert models_wsi_result.exit_code == 0
383+
assert (tmp_path / "output" / (sample_svs.stem + ".db")).exists()

tiatoolbox/cli/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tiatoolbox.cli.patch_predictor import patch_predictor
1212
from tiatoolbox.cli.read_bounds import read_bounds
1313
from tiatoolbox.cli.save_tiles import save_tiles
14-
from tiatoolbox.cli.semantic_segment import semantic_segment
14+
from tiatoolbox.cli.semantic_segmentor import semantic_segmentor
1515
from tiatoolbox.cli.show_wsi import show_wsi
1616
from tiatoolbox.cli.slide_info import slide_info
1717
from tiatoolbox.cli.slide_thumbnail import slide_thumbnail
@@ -42,7 +42,7 @@ def main() -> click.BaseCommand:
4242
main.add_command(patch_predictor)
4343
main.add_command(read_bounds)
4444
main.add_command(save_tiles)
45-
main.add_command(semantic_segment)
45+
main.add_command(semantic_segmentor)
4646
main.add_command(slide_info)
4747
main.add_command(slide_thumbnail)
4848
main.add_command(tissue_mask)

tiatoolbox/cli/common.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,28 @@ def cli_masks(
387387
)
388388

389389

390-
def cli_auto_generate_mask(
390+
def cli_memory_threshold(
391+
usage_help: str = (
392+
"Memory usage threshold (in percentage) to trigger caching behavior."
393+
),
394+
default: int = 80,
395+
) -> Callable:
396+
"""Enables --batch-size option for cli."""
397+
return click.option(
398+
"--memory-threshold",
399+
help=add_default_to_usage_help(usage_help, default=default),
400+
default=default,
401+
)
402+
403+
404+
def cli_auto_get_mask(
391405
usage_help: str = "Automatically generate tile/WSI tissue mask.",
392406
*,
393407
default: bool = False,
394408
) -> Callable:
395409
"""Enables --auto-generate-mask option for cli."""
396410
return click.option(
397-
"--auto-generate-mask",
411+
"--auto-get-mask",
398412
help=add_default_to_usage_help(usage_help, default=default),
399413
type=bool,
400414
default=default,
@@ -415,27 +429,14 @@ def cli_yaml_config_path(
415429
)
416430

417431

418-
def cli_num_loader_workers(
432+
def cli_num_workers(
419433
usage_help: str = "Number of workers to load the data. Please note that they will "
420434
"also perform preprocessing.",
421435
default: int = 0,
422436
) -> Callable:
423437
"""Enables --num-loader-workers option for cli."""
424438
return click.option(
425-
"--num-loader-workers",
426-
help=add_default_to_usage_help(usage_help, default=default),
427-
type=int,
428-
default=default,
429-
)
430-
431-
432-
def cli_num_postproc_workers(
433-
usage_help: str = "Number of workers to post-process the network output.",
434-
default: int = 0,
435-
) -> Callable:
436-
"""Enables --num-postproc-workers option for cli."""
437-
return click.option(
438-
"--num-postproc-workers",
439+
"--num-workers",
439440
help=add_default_to_usage_help(usage_help, default=default),
440441
type=int,
441442
default=default,

tiatoolbox/cli/nucleus_instance_segment.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
import click
66

77
from tiatoolbox.cli.common import (
8-
cli_auto_generate_mask,
8+
cli_auto_get_mask,
99
cli_batch_size,
1010
cli_device,
1111
cli_file_type,
1212
cli_img_input,
1313
cli_masks,
1414
cli_mode,
15-
cli_num_loader_workers,
16-
cli_num_postproc_workers,
15+
cli_num_workers,
1716
cli_output_path,
1817
cli_pretrained_model,
1918
cli_pretrained_weights,
@@ -45,10 +44,9 @@
4544
@cli_batch_size()
4645
@cli_masks(default=None)
4746
@cli_yaml_config_path(default=None)
48-
@cli_num_loader_workers()
47+
@cli_num_workers()
4948
@cli_verbose(default=True)
50-
@cli_num_postproc_workers(default=0)
51-
@cli_auto_generate_mask(default=False)
49+
@cli_auto_get_mask(default=False)
5250
def nucleus_instance_segment(
5351
pretrained_model: str,
5452
pretrained_weights: str,
@@ -60,7 +58,6 @@ def nucleus_instance_segment(
6058
batch_size: int,
6159
yaml_config_path: str,
6260
num_loader_workers: int,
63-
num_postproc_workers: int,
6461
device: str,
6562
*,
6663
auto_generate_mask: bool,
@@ -91,7 +88,6 @@ def nucleus_instance_segment(
9188
pretrained_weights=pretrained_weights,
9289
batch_size=batch_size,
9390
num_loader_workers=num_loader_workers,
94-
num_postproc_workers=num_postproc_workers,
9591
auto_generate_mask=auto_generate_mask,
9692
verbose=verbose,
9793
)

tiatoolbox/cli/patch_predictor.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from __future__ import annotations
44

55
from tiatoolbox.cli.common import (
6+
cli_auto_get_mask,
67
cli_batch_size,
78
cli_device,
89
cli_file_type,
910
cli_img_input,
1011
cli_masks,
12+
cli_memory_threshold,
1113
cli_model,
12-
cli_num_loader_workers,
14+
cli_num_workers,
1315
cli_output_path,
1416
cli_output_type,
1517
cli_patch_mode,
@@ -39,13 +41,15 @@
3941
@cli_batch_size(default=1)
4042
@cli_yaml_config_path()
4143
@cli_masks(default=None)
42-
@cli_num_loader_workers(default=0)
44+
@cli_num_workers(default=0)
4345
@cli_output_type(
4446
default="AnnotationStore",
4547
)
48+
@cli_memory_threshold(default=80)
4649
@cli_patch_mode(default=False)
4750
@cli_return_probabilities(default=True)
4851
@cli_return_labels(default=False)
52+
@cli_auto_get_mask(default=True)
4953
@cli_verbose(default=True)
5054
def patch_predictor(
5155
model: str,
@@ -56,16 +60,18 @@ def patch_predictor(
5660
output_path: str,
5761
batch_size: int,
5862
yaml_config_path: str,
59-
num_loader_workers: int,
63+
num_workers: int,
6064
device: str,
6165
output_type: str,
66+
memory_threshold: int,
6267
*,
68+
patch_mode: bool,
6369
return_probabilities: bool,
6470
return_labels: bool,
65-
patch_mode: bool,
71+
auto_get_mask: bool,
6672
verbose: bool,
6773
) -> None:
68-
"""Process an image/directory of input images with a patch classification CNN."""
74+
"""Process an image/directory of input images with a patch classification engine."""
6975
from tiatoolbox.models.engine.io_config import ( # noqa: PLC0415
7076
IOPatchPredictorConfig,
7177
)
@@ -82,7 +88,7 @@ def patch_predictor(
8288
model=model,
8389
weights=weights,
8490
batch_size=batch_size,
85-
num_workers=num_loader_workers,
91+
num_workers=num_workers,
8692
verbose=verbose,
8793
)
8894

@@ -102,4 +108,6 @@ def patch_predictor(
102108
output_type=output_type,
103109
return_probabilities=return_probabilities,
104110
return_labels=return_labels,
111+
auto_get_mask=auto_get_mask,
112+
memory_threshold=memory_threshold,
105113
)

tiatoolbox/cli/semantic_segment.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

0 commit comments

Comments
 (0)