Skip to content

Commit 32d313f

Browse files
Merge pull request #83 from computational-cell-analytics/mypy
Fix some issues raised by mypy
2 parents 4ce846c + 6817c19 commit 32d313f

File tree

3 files changed

+20
-26
lines changed

3 files changed

+20
-26
lines changed

micro_sam/instance_segmentation.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import multiprocessing as mp
22
import warnings
33
from abc import ABC
4-
from collections.abc import Mapping
54
from concurrent import futures
65
from copy import deepcopy
7-
from typing import Any, List, Optional
6+
from typing import Any, Dict, List, Optional
87

98
import numpy as np
109
import torch
@@ -46,7 +45,7 @@ def __getitem__(self, index):
4645

4746

4847
def mask_data_to_segmentation(
49-
masks: List[Mapping[str, Any]],
48+
masks: List[Dict[str, Any]],
5049
shape: tuple[int, ...],
5150
with_background: bool,
5251
) -> np.ndarray:
@@ -256,7 +255,7 @@ def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, cr
256255

257256
return curr_anns
258257

259-
def get_state(self) -> Mapping[str, Any]:
258+
def get_state(self) -> Dict[str, Any]:
260259
"""Get the initialized state of the mask generator.
261260
262261
Returns:
@@ -266,7 +265,7 @@ def get_state(self) -> Mapping[str, Any]:
266265
raise RuntimeError("The state has not been computed yet. Call initialize first.")
267266
return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
268267

269-
def set_state(self, state: Mapping[str, Any]) -> None:
268+
def set_state(self, state: Dict[str, Any]) -> None:
270269
"""Set the state of the mask generator.
271270
272271
Args:
@@ -447,7 +446,7 @@ def generate(
447446
crop_nms_thresh: float = 0.7,
448447
min_mask_region_area: int = 0,
449448
output_mode: str = "binary_mask",
450-
) -> List[Mapping[str, Any]]:
449+
) -> List[Dict[str, Any]]:
451450
"""Generate instance segmentation for the currently initialized image.
452451
453452
Args:
@@ -486,7 +485,7 @@ def generate(
486485
data["boxes"].float(),
487486
scores,
488487
torch.zeros_like(data["boxes"][:, 0]), # categories
489-
iou_threshold=self.crop_nms_thresh,
488+
iou_threshold=crop_nms_thresh,
490489
)
491490
data.filter(keep_by_nms)
492491

@@ -648,7 +647,7 @@ def generate(
648647
box_nms_thresh: float = 0.7,
649648
min_mask_region_area: int = 0,
650649
output_mode: str = "binary_mask",
651-
) -> List[Mapping[str, Any]]:
650+
) -> List[Dict[str, Any]]:
652651
"""Generate instance segmentation for the currently initialized image.
653652
654653
Args:
@@ -699,7 +698,7 @@ def get_initial_segmentation(self) -> np.ndarray:
699698
raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
700699
return self._resize_segmentation(self._initial_segmentation, self.original_size)
701700

702-
def get_state(self) -> Mapping[str, Any]:
701+
def get_state(self) -> Dict[str, Any]:
703702
"""Get the initialized state of the mask generator.
704703
705704
Returns:
@@ -709,7 +708,7 @@ def get_state(self) -> Mapping[str, Any]:
709708
state["initial_segmentation"] = self._initial_segmentation
710709
return state
711710

712-
def set_state(self, state: Mapping[str, Any]) -> None:
711+
def set_state(self, state: Dict[str, Any]) -> None:
713712
"""Set the state of the mask generator.
714713
715714
Args:
@@ -918,7 +917,7 @@ def segment_tile(_, tile_id):
918917

919918
return segmentation
920919

921-
def get_initial_segmentation(self) -> None:
920+
def get_initial_segmentation(self) -> np.ndarray:
922921
"""Get the initial instance segmentation.
923922
924923
Returns:
@@ -947,7 +946,7 @@ def segment_tile(_, tile_id):
947946
self._stitched_initial_segmentation = initial_segmentation
948947
return initial_segmentation
949948

950-
def get_state(self) -> Mapping[str, Any]:
949+
def get_state(self) -> Dict[str, Any]:
951950
"""Get the initialized state of the mask generator.
952951
953952
Returns:
@@ -958,7 +957,7 @@ def get_state(self) -> Mapping[str, Any]:
958957
state["halo"] = self._halo
959958
return state
960959

961-
def set_state(self, state: Mapping[str, Any]) -> None:
960+
def set_state(self, state: Dict[str, Any]) -> None:
962961
"""Set the state of the mask generator.
963962
964963
Args:

micro_sam/prompt_generators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
dilation_strength: int,
3838
get_point_prompts: bool = True,
3939
get_box_prompts: bool = False
40-
):
40+
) -> None:
4141
self.n_positive_points = n_positive_points
4242
self.n_negative_points = n_negative_points
4343
self.dilation_strength = dilation_strength

micro_sam/util.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import hashlib
22
import os
33
import warnings
4-
from collections.abc import Mapping
54
from shutil import copyfileobj
6-
from typing import Any, Optional
5+
from typing import Any, Callable, Dict, Optional, Tuple
76

7+
import imageio.v3 as imageio
88
import numpy as np
99
import requests
1010
import torch
@@ -17,11 +17,6 @@
1717

1818
from segment_anything import sam_model_registry, SamPredictor
1919

20-
try:
21-
import imageio.v2 as imageio
22-
except ImportError:
23-
import imageio
24-
2520
try:
2621
from napari.utils import progress as tqdm
2722
except ImportError:
@@ -339,9 +334,9 @@ def precompute_image_embeddings(
339334
save_path: Optional[str] = None,
340335
lazy_loading: bool = False,
341336
ndim: Optional[int] = None,
342-
tile_shape: Optional[tuple[int]] = None,
343-
halo: Optional[tuple[int]] = None,
344-
wrong_file_callback: Optional[callable] = None,
337+
tile_shape: Optional[Tuple[int]] = None,
338+
halo: Optional[Tuple[int]] = None,
339+
wrong_file_callback: Optional[Callable] = None,
345340
) -> ImageEmbeddings:
346341
"""Compute the image embeddings (output of the encoder) for the input.
347342
@@ -453,7 +448,7 @@ def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
453448
def get_centers_and_bounding_boxes(
454449
segmentation: np.ndarray,
455450
mode: str = "v"
456-
) -> tuple[Mapping[int, np.ndarray], Mapping[int, tuple]]:
451+
) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
457452
"""Returns the center coordinates of the foreground instances in the ground-truth.
458453
459454
Args:
@@ -500,7 +495,7 @@ def load_image_data(
500495
The image data.
501496
"""
502497
if key is None:
503-
image_data = imageio.imread(path) if ndim == 2 else imageio.volread(path)
498+
image_data = imageio.imread(path)
504499
else:
505500
with open_file(path, mode="r") as f:
506501
image_data = f[key]

0 commit comments

Comments
 (0)