
Commit 6ad56fe

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 35237c4 commit 6ad56fe

File tree

7 files changed: +376 -242 lines

examples/sam-architecture.ipynb

Lines changed: 216 additions & 111 deletions
Large diffs are not rendered by default.

examples/tiaviz-test.ipynb

Lines changed: 10 additions & 10 deletions
@@ -19,8 +19,8 @@
     }
    ],
    "source": [
+    "from tiatoolbox.models.architecture.sam import SAM\n",
     "from tiatoolbox.models.engine.general_segmentor import GeneralSegmentor\n",
-    "from tiatoolbox.models.architecture.sam import SAM, SAMPrompts \n",
     "\n",
     "# abc = GeneralSegmentor(model=SAM())\n",
     "# prompts = SAMPrompts([[100,100]])\n",
@@ -42,14 +42,15 @@
    ],
    "source": [
     "from pathlib import Path\n",
+    "\n",
     "model = GeneralSegmentor(SAM())\n",
     "\n",
     "glands = \"slides/glands.png\"\n",
     "slides = \"slides/sample_wsi.svs\"\n",
-    "glands_prompts = [(370,270),(300,400)]\n",
-    "slides_prompts = [[5792,6018]]\n",
-    "slides_location = (5745,5972)\n",
-    "slides_size = (200,114)"
+    "glands_prompts = [(370, 270), (300, 400)]\n",
+    "slides_prompts = [[5792, 6018]]\n",
+    "slides_location = (5745, 5972)\n",
+    "slides_size = (200, 114)"
    ]
   },
   {
@@ -74,8 +75,9 @@
    ],
    "source": [
     "prompts = model.create_prompts(slides_prompts)\n",
-    "output = model.predict(slides, prompts, \"cpu\", \"overlays\", slides_location,slides_size, 0.5)\n",
-    "\n"
+    "output = model.predict(\n",
+    "    slides, prompts, \"cpu\", \"overlays\", slides_location, slides_size, 0.5\n",
+    ")"
    ]
   },
   {
@@ -84,9 +86,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
-    "save_path = model.to_annotation(output[0][1], output[0][2], Path(\"overlays/sample_wsi\"))\n",
-    "\n"
+    "save_path = model.to_annotation(output[0][1], output[0][2], Path(\"overlays/sample_wsi\"))"
    ]
   }
  ],
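
Read top to bottom, the cells in this notebook form a short prompt-driven segmentation pipeline. A minimal sketch of the same flow as a plain script, assuming the SAM / GeneralSegmentor API exactly as it appears in this diff (the positional arguments of model.predict and the output indexing are copied from the cells above, not from a released tiatoolbox API):

from pathlib import Path

from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.models.engine.general_segmentor import GeneralSegmentor

# Wrap the SAM architecture in the new general-purpose segmentor engine.
model = GeneralSegmentor(SAM())

slides = "slides/sample_wsi.svs"
slides_prompts = [[5792, 6018]]  # a single (x, y) point prompt on the WSI
slides_location = (5745, 5972)   # top-left corner of the region to segment
slides_size = (200, 114)         # size of that region in pixels

# Build prompt objects and run CPU inference, writing results under "overlays";
# the trailing 0.5 is passed exactly as in the notebook cell above.
prompts = model.create_prompts(slides_prompts)
output = model.predict(
    slides, prompts, "cpu", "overlays", slides_location, slides_size, 0.5
)

# Persist the first result (indexed as in the notebook) as an annotation store.
save_path = model.to_annotation(
    output[0][1], output[0][2], Path("overlays/sample_wsi")
)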

tests/models/test_arch_sam.py

Lines changed: 9 additions & 12 deletions
@@ -3,13 +3,10 @@
 from pathlib import Path
 from typing import Callable
 
-import numpy as np
 import pytest
-import torch
 
 from tiatoolbox.models import SAM
 from tiatoolbox.models.architecture.sam import SAMPrompts
-from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
 
 ON_GPU = False
@@ -32,21 +29,21 @@ def test_functional_sam(
     # test inference
     # create prompts
 
-    prompts1 = SAMPrompts(point_coords=[[64,64]])
-    prompts2 = SAMPrompts(point_coords=[[64,64]], point_labels=[1])
-    prompts3 = SAMPrompts(box_coords=[[64,64,128,128]])
-    prompts4 = SAMPrompts(point_coords=[[64,64]], point_labels=[1], box_coords=[[64,64,128,128]])
+    prompts1 = SAMPrompts(point_coords=[[64, 64]])
+    prompts2 = SAMPrompts(point_coords=[[64, 64]], point_labels=[1])
+    prompts3 = SAMPrompts(box_coords=[[64, 64, 128, 128]])
+    prompts4 = SAMPrompts(
+        point_coords=[[64, 64]], point_labels=[1], box_coords=[[64, 64, 128, 128]]
+    )
 
     model = SAM()
 
     # load pretrained weights
-    #pretrained = torch.load(weights_path, map_location="cpu")
-    #model.load_state_dict(pretrained)
+    # pretrained = torch.load(weights_path, map_location="cpu")
+    # model.load_state_dict(pretrained)
 
-    _ = model.infer_batch(model, img, on_gpu=ON_GPU) # no prompts
+    _ = model.infer_batch(model, img, on_gpu=ON_GPU)  # no prompts
     _ = model.infer_batch(model, img, prompts=prompts1, on_gpu=ON_GPU)
     _ = model.infer_batch(model, img, prompts=prompts2, on_gpu=ON_GPU)
     _ = model.infer_batch(model, img, prompts=prompts3, on_gpu=ON_GPU)
     _ = model.infer_batch(model, img, prompts=prompts4, on_gpu=ON_GPU)
-
-
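
The four SAMPrompts variants above exercise every optional-argument combination the class accepts: points only, points with labels, a box only, and points plus a box. One behaviour worth noting, per the SAMPrompts.__init__ in the sam.py diff below, is that point labels default to foreground when omitted. A minimal sketch, assuming the constructor exactly as written in this commit:

from tiatoolbox.models.architecture.sam import SAMPrompts

# Points without labels: __init__ fills in one foreground label (1) per point.
prompts = SAMPrompts(point_coords=[[64, 64], [32, 32]])
assert prompts.point_labels == [1, 1]

# Box-only prompt, passed as [x0, y0, x1, y1] as in the test above;
# point fields stay None, so the predictor receives no point inputs.
box_only = SAMPrompts(box_coords=[[64, 64, 128, 128]])
assert box_only.point_coords is None and box_only.point_labels is None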

tiatoolbox/models/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -8,11 +8,11 @@
 from .architecture.mapde import MapDe
 from .architecture.micronet import MicroNet
 from .architecture.nuclick import NuClick
-from .architecture.sccnn import SCCNN
 from .architecture.sam import SAM
+from .architecture.sccnn import SCCNN
+from .engine.general_segmentor import GeneralSegmentor
 from .engine.multi_task_segmentor import MultiTaskSegmentor
 from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
-from .engine.general_segmentor import GeneralSegmentor
 from .engine.patch_predictor import (
     IOPatchPredictorConfig,
     PatchDataset,
@@ -27,17 +27,17 @@
 )
 
 __all__ = [
+    "SAM",
     "SCCNN",
+    "GeneralSegmentor",
     "HoVerNet",
     "HoVerNetPlus",
     "IDaRS",
     "MapDe",
     "MicroNet",
     "MultiTaskSegmentor",
     "NuClick",
-    "SAM",
     "NucleusInstanceSegmentor",
     "PatchPredictor",
     "SemanticSegmentor",
-    "GeneralSegmentor",
 ]

tiatoolbox/models/architecture/sam.py

Lines changed: 31 additions & 26 deletions
@@ -2,30 +2,24 @@
 
 from __future__ import annotations
 
-from collections import OrderedDict
-
 import numpy as np
 import torch
-import torch.nn.functional as F  # noqa: N812
-from skimage import morphology
-from torch import nn
-
-from tiatoolbox.utils import misc
-from tiatoolbox.models.models_abc import ModelABC
-
+from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.build_sam import build_sam2, build_sam2_hf
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+
+from tiatoolbox.models.models_abc import ModelABC
 
 
-class SAMPrompts():
+class SAMPrompts:
     """Structure of prompts for SAM."""
-    def __init__(self, point_coords = None, point_labels = None, box_coords = None):
+
+    def __init__(self, point_coords=None, point_labels=None, box_coords=None):
         self.point_coords = None if point_coords == [] else point_coords
         self.box_coords = None if box_coords == [] else box_coords
-        if(point_coords and point_labels is None):
+        if point_coords and point_labels is None:
             self.point_labels = [1] * len(point_coords)
-        else:
+        else:
             self.point_labels = point_labels
 
 
@@ -44,15 +38,15 @@ def __init__(
             self.model = build_sam2_hf(model_hf_path, device="cpu")
         else:
             self.model = build_sam2(model_cfg_path, checkpoint_path)
-
+
         self.predictor = SAM2ImagePredictor(self.model)
         self.generator = SAM2AutomaticMaskGenerator(self.model)
 
     def forward(self: SAM, image: np.ndarray, prompts: SAMPrompts = None) -> np.ndarray:
         """Torch method, this contains logic for using layers defined in init."""
         mask = self.generate_mask(self, image, prompts)
         return mask
-
+
     @staticmethod
     def infer_batch(
         model: torch.nn.Module,
@@ -78,7 +72,9 @@ def infer_batch(
         model.eval()
         model = model.to(device)
 
-        if isinstance(batch_data, torch.Tensor): # Move the tensor to the CPU if it's a PyTorch tensor
+        if isinstance(
+            batch_data, torch.Tensor
+        ):  # Move the tensor to the CPU if it's a PyTorch tensor
             batch_data = batch_data.to(device).type(torch.float32)
             batch_data = batch_data.cpu().numpy()
 
@@ -92,7 +88,7 @@
     def encode_image(self, image: np.ndarray) -> np.ndarray:
         """Encodes the image for feature extraction."""
         self.predictor.set_image(image)
-
+
     @staticmethod
     def generate_mask(self, features: np.ndarray, prompts: SAMPrompts) -> np.ndarray:
         """Generates a segmentation mask using SAM 2, optionally guided by a prompt."""
@@ -109,24 +105,33 @@ def generate_mask(self, features: np.ndarray, prompts: SAMPrompts) -> np.ndarray
             scores = np.around(scores[sorted_ind], 2)
         else:
             masks = self.generator.generate(features)
-            scores = np.array([mask['predicted_iou'] for mask in masks])
+            scores = np.array([mask["predicted_iou"] for mask in masks])
         return masks, scores
-
+
     @staticmethod
     def load_weights(self, checkpoint_path: str) -> None:
         """Loads model weights from specified checkpoint."""
-        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
+        self.model.load_state_dict(
+            torch.load(checkpoint_path, map_location=self.device)
+        )
 
     @staticmethod
     def preproc(image: np.ndarray) -> np.ndarray:
         """Pre-processes images - Converts them into a format accepted by SAM (HWC) from NCHW."""
-        if isinstance(image, torch.Tensor): # Move the tensor to the CPU if it's a PyTorch tensor
+        if isinstance(
+            image, torch.Tensor
+        ):  # Move the tensor to the CPU if it's a PyTorch tensor
            image = image.cpu().numpy()
-
+
         # Handle different shapes
-        if image.ndim == 4 and image.shape == (1,512,512,3): # Case 1: (N, H, W, C)
+        if image.ndim == 4 and image.shape == (1, 512, 512, 3):  # Case 1: (N, H, W, C)
             image = np.squeeze(image, axis=0)  # Remove batch dimension
-        elif image.ndim == 4 and image.shape == (1,3,512,512): # Case 2: (N, C, H, W)
+        elif image.ndim == 4 and image.shape == (
+            1,
+            3,
+            512,
+            512,
+        ):  # Case 2: (N, C, H, W)
             image = np.squeeze(image, axis=0)  # Remove batch dimension
             image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
 
@@ -136,4 +141,4 @@ def preproc(image: np.ndarray) -> np.ndarray:
     @staticmethod
     def postproc(image: np.ndarray) -> np.ndarray:
         """Define the post-processing of this class of model."""
-        return image
+        return image
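
The preproc change above only reflows the existing shape checks, but the logic is worth spelling out: SAM 2 consumes HWC images, while tiatoolbox batches can arrive as NHWC or NCHW arrays. A minimal NumPy sketch of that conversion, generalized from the hardcoded (1, 512, 512, 3) / (1, 3, 512, 512) comparisons in this commit:

import numpy as np

def to_hwc(image: np.ndarray) -> np.ndarray:
    """Collapse a single-image batch to the (H, W, C) layout SAM 2 expects."""
    if image.ndim == 4 and image.shape[-1] == 3:  # (N, H, W, C)
        image = np.squeeze(image, axis=0)  # drop the batch dimension
    elif image.ndim == 4 and image.shape[1] == 3:  # (N, C, H, W)
        image = np.squeeze(image, axis=0)  # drop the batch dimension
        image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    return image

assert to_hwc(np.zeros((1, 512, 512, 3))).shape == (512, 512, 3)
assert to_hwc(np.zeros((1, 3, 512, 512))).shape == (512, 512, 3)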
