Skip to content

Commit 62cfe01

Browse files
committed
🐛 Address Co-Pilot suggestions.
1 parent 227e317 commit 62cfe01

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

tests/engines/test_feature_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) ->
8989

9090
output_ = zarr.open(output[mini_wsi_svs], mode="r")
9191
assert len(output_["coordinates"].shape) == 2
92-
assert len(output_["probabilities"].shape)
92+
assert len(output_["probabilities"].shape) == 4
9393

9494

9595
@pytest.mark.parametrize(

tiatoolbox/models/engine/deep_feature_extractor.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@
6868
def save_to_cache(
6969
probabilities: list[da.Array],
7070
coordinates: list[da.Array],
71-
probabilities_zarr: zarr.Array,
72-
coordinates_zarr: zarr.Array,
71+
probabilities_zarr: zarr.Array | None,
72+
coordinates_zarr: zarr.Array | None,
7373
save_path: str | Path = "temp.zarr",
7474
) -> tuple[zarr.Array, zarr.Array]:
7575
"""Save computed feature and coordinate arrays to Zarr cache.
@@ -217,7 +217,7 @@ def __init__(
217217
self.process_prediction_per_batch = False
218218

219219
def infer_wsi(
220-
self: SemanticSegmentor,
220+
self: DeepFeatureExtractor,
221221
dataloader: DataLoader,
222222
save_path: Path,
223223
**kwargs: Unpack[SemanticSegmentorRunParams],
@@ -251,7 +251,6 @@ def infer_wsi(
251251
# Default Memory threshold percentage is 80.
252252
memory_threshold = kwargs.get("memory_threshold", 80)
253253
vm = psutil.virtual_memory()
254-
_ = save_path
255254
keys = ["probabilities", "coordinates"]
256255
probabilities, coordinates = [], []
257256

@@ -376,7 +375,7 @@ def post_process_patches(
376375
return raw_predictions
377376

378377
def save_predictions(
379-
self: SemanticSegmentor,
378+
self: DeepFeatureExtractor,
380379
processed_predictions: dict,
381380
output_type: str,
382381
save_path: Path | None = None,
@@ -418,7 +417,7 @@ def save_predictions(
418417
)
419418

420419
def _update_run_params(
421-
self: SemanticSegmentor,
420+
self: DeepFeatureExtractor,
422421
images: list[os.PathLike | Path | WSIReader] | np.ndarray,
423422
masks: list[os.PathLike | Path] | np.ndarray | None = None,
424423
labels: list | None = None,
@@ -539,7 +538,7 @@ def run(
539538
540539
Raises:
541540
ValueError:
542-
If `output_type` is not "zarr".
541+
If `output_type` is not "zarr" or "dict".
543542
"""
544543
# return_probabilities is always True for FeatureExtractor.
545544
kwargs["return_probabilities"] = True

0 commit comments

Comments
 (0)