Commit a4e3a53

Merge pull request #852 from computational-cell-analytics/dev
Prepare release 1.2.1
2 parents 1bc787b + aa54b05 commit a4e3a53

49 files changed: +866, -375 lines changed
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Segment Anything for Histopathology
+
+This is a [Segment Anything](https://segment-anything.com/) model that was specialized for histopathology with [micro_sam](https://github.com/computational-cell-analytics/micro-sam).
+This model uses a %s vision transformer as its image encoder.
+
+Segment Anything is a model for interactive and automatic instance segmentation.
+We improve it for histopathology by finetuning on a large and diverse microscopy dataset.
+It should perform well for nucleus segmentation in histopathology datasets.
+
+See [the dataset overview](https://github.com/computational-cell-analytics/micro-sam/blob/master/doc/datasets/histopathology_v%i.md) for further information on the training data and the [micro_sam documentation](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html) for details on how to use the model for interactive and automatic segmentation.
+
+## Validation
+
+The easiest way to validate the model is to visually check the segmentation quality for your data.
+If you have annotations you can use for validation, you can also run quantitative validation; see [here for details](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#9-how-can-i-evaluate-a-model-i-have-finetuned).
+Please note that the required quality for segmentation always depends on the analysis task you want to solve.
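To make the quantitative route above concrete, here is a minimal validation sketch, assuming you already have a predicted instance segmentation and matching ground-truth labels as numpy arrays; the file paths are placeholders, and mean_segmentation_accuracy comes from elf, which micro_sam's evaluation utilities build on.

import numpy as np
from elf.evaluation import mean_segmentation_accuracy

# Placeholder paths; both arrays hold integer instance labels of the same shape.
segmentation = np.load("predicted_instances.npy")
groundtruth = np.load("groundtruth_instances.npy")

# Mean segmentation accuracy over the standard IoU thresholds,
# plus the per-threshold accuracies.
msa, accuracies = mean_segmentation_accuracy(segmentation, groundtruth, return_accuracies=True)
print("mSA:", msa)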

finetuning/livecell_finetuning.py

Lines changed: 15 additions & 2 deletions
@@ -56,6 +56,19 @@ def finetune_livecell(args):
     train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
     scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}

+    # NOTE: memory requirements of the different finetuning approaches (compared on A100 80GB)
+    # vit_b
+    # freeze_encoder: ~33.89 GB
+    # QLoRA: ~48.54 GB
+    # LoRA: ~48.62 GB
+    # FFT: ~49.56 GB
+
+    # vit_h
+    # freeze_encoder: ~36.05 GB
+    # QLoRA: ~65.68 GB
+    # LoRA: ~67.14 GB
+    # FFT: ~72.34 GB
+
     # Run training.
     sam_training.train_sam(
         name=checkpoint_name,
@@ -72,7 +85,7 @@ def finetune_livecell(args):
         save_root=args.save_root,
         scheduler_kwargs=scheduler_kwargs,
         save_every_kth_epoch=args.save_every_kth_epoch,
-        peft_kwargs={"rank": args.lora_rank} if args.lora_rank is not None else None,
+        peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
     )

     if args.export_path is not None:
@@ -87,7 +100,7 @@ def finetune_livecell(args):
 def main():
     parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
     parser.add_argument(
-        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
+        "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/data/livecell/",
         help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
     )
     parser.add_argument(
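To illustrate what the new peft_kwargs enable, a minimal sketch of a QLoRA finetuning call, assuming the get_dataloaders helper defined at the top of this script and LIVECell data at the given path; the checkpoint name and rank are placeholders.

import micro_sam.training as sam_training

# get_dataloaders is the helper defined earlier in this script; (520, 704) is
# the patch shape the script uses for LIVECell.
train_loader, val_loader = get_dataloaders(patch_shape=(520, 704), data_path="./livecell")

sam_training.train_sam(
    name="livecell_sam_qlora",  # placeholder checkpoint name
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    # "rank" enables LoRA; "quantize": True additionally quantizes the frozen
    # base weights, which is what turns LoRA into QLoRA here.
    peft_kwargs={"rank": 4, "quantize": True},
)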

micro_sam/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 .. include:: ../doc/contributing.md
 .. include:: ../doc/band.md
 """
+
 import os

 from .__version__ import __version__

micro_sam/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"

micro_sam/_vendored.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
-"""
-Functions from other third party libraries.
+"""Functions from other third party libraries.

 We can remove these functions once the bugs affecting our code are fixed upstream.

 The license type of the third party software project must be compatible with
 the software license the micro-sam project is distributed under.
 """

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal

 import numpy as np

@@ -109,7 +108,9 @@ def _compute_rle_numpy(mask):
     return counts


-def mask_to_rle_pytorch(tensor: torch.Tensor, rle_implementation: str = "default") -> List[Dict[str, Any]]:
+def mask_to_rle_pytorch(
+    tensor: torch.Tensor, rle_implementation: Literal["default", "numpy", "numba", "nifty"] = "default"
+) -> List[Dict[str, Any]]:
     """Calculates the runlength encoding of binary input masks.

     This replaces the function in
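A quick usage sketch for the narrowed signature; the mask below is made-up example data, and the backend names are exactly the options of the Literal above.

import torch
from micro_sam._vendored import mask_to_rle_pytorch

# A single 8x8 binary mask with one rectangular object (made-up example data).
masks = torch.zeros((1, 8, 8), dtype=torch.bool)
masks[0, 2:5, 3:6] = True

# "default" keeps the torch implementation; "numpy", "numba" and "nifty"
# select the vendored alternatives.
rle = mask_to_rle_pytorch(masks, rle_implementation="numpy")
print(rle[0]["size"])  # [8, 8]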

micro_sam/automatic_segmentation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def automatic_instance_segmentation(
         verbose=verbose,
     )

-    segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
+    # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
+    if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
+        generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
+
+    segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
     masks = segmenter.generate(**generate_kwargs)

     if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
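From the caller's side, a hedged sketch of tiled AIS, assuming get_predictor_and_segmenter as the entry point; model type, input image, tile shape and halo are placeholders. With this change, the tile_shape and halo passed here are also forwarded to the watershed postprocessing of the decoder-based segmenter.

import imageio.v3 as imageio
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

image = imageio.imread("example_image.tif")  # placeholder input image

# is_tiled=True selects the tiled segmenter, so the embeddings and now also
# the watershed postprocessing are computed per tile.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", is_tiled=True)
instances = automatic_instance_segmentation(
    predictor=predictor, segmenter=segmenter, input_path=image,
    tile_shape=(512, 512), halo=(64, 64),
)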

micro_sam/bioimageio/bioengine_export.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 from typing import Optional, Union

 import torch
+
 from segment_anything.utils.onnx import SamOnnxModel

 try:
@@ -67,7 +68,7 @@ def export_image_encoder(
     model_type: str,
     output_root: Union[str, os.PathLike],
     export_name: Optional[str] = None,
-    checkpoint_path: Optional[str] = None,
+    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 ) -> None:
     """Export SAM image encoder to torchscript.

@@ -103,15 +104,16 @@


 def export_onnx_model(
-    model_type,
-    output_root,
-    opset: int,
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    opset: int = 17,
     export_name: Optional[str] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_single_mask: bool = True,
     gelu_approximate: bool = False,
     use_stability_score: bool = False,
     return_extra_metrics: bool = False,
+    quantize_model: bool = False,
 ) -> None:
     """Export SAM prompt encoder and mask decoder to onnx.

@@ -122,14 +124,16 @@
     Args:
         model_type: The SAM model type.
         output_root: The output root directory where the exported model is saved.
-        opset: The ONNX opset version.
+        opset: The ONNX opset version. The recommended opset version is 17.
         export_name: The name of the exported model.
         checkpoint_path: Optional checkpoint for loading the SAM model.
         return_single_mask: Whether the mask decoder returns a single or multiple masks.
         gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
             does not have an efficient GeLU implementation.
         use_stability_score: Whether to use the stability score instead of the predicted score.
         return_extra_metrics: Whether to return a larger set of metrics.
+        quantize_model: Whether to also export a quantized version of the model.
+            This only works for onnxruntime < 1.17.
     """
     if export_name is None:
         export_name = model_type
@@ -154,10 +158,7 @@
         if isinstance(m, torch.nn.GELU):
             m.approximate = "tanh"

-    dynamic_axes = {
-        "point_coords": {1: "num_points"},
-        "point_labels": {1: "num_points"},
-    }
+    dynamic_axes = {"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}}

     embed_dim = sam.prompt_encoder.embed_dim
     embed_size = sam.prompt_encoder.image_embedding_size
@@ -202,14 +203,31 @@
     _ = ort_session.run(None, ort_inputs)
     print("Model has successfully been run with ONNXRuntime.")

+    # This requires onnxruntime < 1.17.
+    # See https://github.com/facebookresearch/segment-anything/issues/699#issuecomment-1984670808
+    if quantize_model:
+        assert onnxruntime_exists
+        from onnxruntime.quantization import QuantType
+        from onnxruntime.quantization.quantize import quantize_dynamic
+
+        quantized_path = os.path.join(weight_output_folder, "model_quantized.onnx")
+        quantize_dynamic(
+            model_input=weight_path,
+            model_output=quantized_path,
+            # optimize_model=True,
+            per_channel=False,
+            reduce_range=False,
+            weight_type=QuantType.QUInt8,
+        )
+
     config_output_path = os.path.join(output_folder, "config.pbtxt")
     with open(config_output_path, "w") as f:
         f.write(DECODER_CONFIG % name)


 def export_bioengine_model(
-    model_type,
-    output_root,
+    model_type: str,
+    output_root: Union[str, os.PathLike],
     opset: int,
     export_name: Optional[str] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
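For reference, a minimal sketch of the new export path with a placeholder output directory; as the comments in the diff note, the quantization step assumes onnxruntime < 1.17.

from micro_sam.bioimageio.bioengine_export import export_onnx_model

export_onnx_model(
    model_type="vit_b",
    output_root="./bioengine_export",  # placeholder output directory
    opset=17,                          # now the default, matching the recommendation above
    quantize_model=True,               # additionally writes model_quantized.onnx
)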

micro_sam/bioimageio/model_export.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@
     "tags": ["segment-anything", "instance-segmentation"],
 }

+# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269  # noqa
+# We had this parameter from bioimageio.spec, but it has been removed there, so we keep a copy of the same parameter.
+ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)
+
+
+def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

-def _create_test_inputs_and_outputs(
-    image,
-    labels,
-    model_type,
-    checkpoint_path,
-    tmp_dir,
-):
     # For now we just generate a single box prompt here, but we could also generate more input prompts.
     generator = PointAndBoxPromptGenerator(
         n_positive_points=1,
@@ -59,10 +58,7 @@ def _create_test_inputs_and_outputs(

     # Generate logits from the two masks.
     mask_prompts = np.stack(
-        [
-            _compute_logits_from_mask(labels == 1),
-            _compute_logits_from_mask(labels == 2),
-        ]
+        [_compute_logits_from_mask(labels == 1), _compute_logits_from_mask(labels == 2)]
     )[None]

     predictor = PredictorAdaptor(model_type=model_type)
@@ -104,11 +100,7 @@
         "point_labels": point_label_path,
         "mask_prompts": mask_prompt_path,
     }
-    outputs = {
-        "mask": mask_path,
-        "score": score_path,
-        "embeddings": embed_path
-    }
+    outputs = {"mask": mask_path, "score": score_path, "embeddings": embed_path}
     return inputs, outputs


@@ -161,6 +153,7 @@ def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
     return checkpoint_path, None


+# TODO: Update this with our latest yaml file updates.
 def _write_dependencies(dependency_file, require_mobile_sam):
     content = """name: sam
 channels:
@@ -215,7 +208,7 @@ def _check_model(model_description, input_paths, result_paths):
     image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
     embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
     box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
-    point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic"))
+    point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("bhwc"))
     point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
     mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))

@@ -303,8 +296,8 @@ def export_sam_model(
             # NOTE: to support 1 and 3 channels we can add another preprocessing.
             # Best solution: Have a pre-processing for this! (1C -> RGB)
             spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
-            spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE),
-            spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE),
+            spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
+            spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
         ],
         test_tensor=spec.FileDescr(source=input_paths["image"]),
         data=spec.IntervalOrRatioDataDescr(type="uint8")
@@ -318,7 +311,7 @@
             spec.BatchAxis(size=1),
             spec.IndexInputAxis(
                 id=spec.AxisId("object"),
-                size=spec.ARBITRARY_SIZE
+                size=ARBITRARY_SIZE
             ),
             spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
         ],
@@ -334,11 +327,11 @@
             spec.BatchAxis(size=1),
             spec.IndexInputAxis(
                 id=spec.AxisId("object"),
-                size=spec.ARBITRARY_SIZE
+                size=ARBITRARY_SIZE
             ),
             spec.IndexInputAxis(
                 id=spec.AxisId("point"),
-                size=spec.ARBITRARY_SIZE
+                size=ARBITRARY_SIZE
             ),
             spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
         ],
@@ -354,11 +347,11 @@
             spec.BatchAxis(size=1),
             spec.IndexInputAxis(
                 id=spec.AxisId("object"),
-                size=spec.ARBITRARY_SIZE
+                size=ARBITRARY_SIZE
            ),
             spec.IndexInputAxis(
                 id=spec.AxisId("point"),
-                size=spec.ARBITRARY_SIZE
+                size=ARBITRARY_SIZE
             ),
         ],
         test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
@@ -373,7 +366,7 @@
             spec.BatchAxis(size=1),
             spec.IndexInputAxis(
                 id=spec.AxisId("object"),
-                size=spec.ARBITRARY_SIZE
+                size=ARBITRARY_SIZE
             ),
             spec.ChannelAxis(channel_names=["channel"]),
             spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
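For context, a small sketch of the ARBITRARY_SIZE shim in isolation; the import line is an assumption about how this module aliases the bioimageio model spec.

import bioimageio.spec.model.v0_5 as spec

# Local copy of the parameter that bioimageio.spec used to ship as
# spec.ARBITRARY_SIZE: any axis size >= 1, growing in steps of 1.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)

# Used exactly like the old constant when describing flexible axes.
y_axis = spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE)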

micro_sam/bioimageio/predictor_adaptor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def forward(
             embeddings: precomputed image embeddings B x 256 x 64 x 64

         Returns:
+            The segmentation masks.
+            The scores for prediction quality.
+            The computed image embeddings.
         """
         batch_size = image.shape[0]
         if batch_size != 1:

micro_sam/evaluation/experiments.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@


 def full_experiment_settings(
-    use_boxes: bool = False,
-    positive_range: Optional[List[int]] = None,
-    negative_range: Optional[List[int]] = None,
+    use_boxes: bool = False, positive_range: Optional[List[int]] = None, negative_range: Optional[List[int]] = None,
 ) -> ExperimentSettings:
     """The full experiment settings.

