
Commit ea84438

Merge pull request #12 from computational-cell-analytics/misc

Expose model type as argument

2 parents 827b7ed + c10f3fe

5 files changed: +46 −11 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,8 @@ We implement napari applications for:
 This is an early beta version. Any feedback is welcome, but please be aware that the functionality is under active development and that several features are not finalized or thoroughly tested yet.
 Once the functionality has matured we plan to release the interactive annotation applications as [napari plugins](https://napari.org/stable/plugins/index.html).
 
+If you run into any problems or have questions please open an issue or reach out via [image.sc](https://forum.image.sc/) using the tag `micro-sam` and tagging @constantinpape.
+
 
 ## Functionality overview
 
@@ -143,6 +145,7 @@ TODO link to video tutorial
 
 - By default, the applications pre-compute the image embeddings produced by SegmentAnything and store them on disc. If you are using a CPU this step can take a while for 3d data or timeseries (you will see a progress bar with a time estimate). If you have access to a GPU without graphical interface (e.g. via a local computer cluster or a cloud provider), you can also pre-compute the embeddings there and then copy them to your laptop / local machine to speed this up. You can use the command `micro_sam.precompute_embeddings` for this (it is installed with the rest of the applications). You can specify the location of the precomputed embeddings via the `embedding_path` argument.
 - Most other processing steps are very fast even on a CPU, so interactive annotation is possible. An exception is the automatic segmentation step (2d segmentation), which takes several minutes without a GPU (depending on the image size). For large volumes and timeseries, segmenting an object in 3d / tracking across time can take a couple of minutes with a CPU (it is very fast with a GPU).
+- You can also try using a smaller version of the SegmentAnything model to speed up the computations. For this you can pass the `model_type` argument and either set it to `vit_l` or `vit_b` (default is `vit_h`). However, this may lead to worse results.
 - You can save and load the results from the `committed_objects` / `committed_tracks` layer to correct segmentations you obtained from another tool (e.g. CellPose) or to save intermediate annotation results. The results can be saved via `File->Save Selected Layer(s) ...` in the napari menu (see the tutorial videos for details). They can be loaded again by specifying the corresponding location via the `segmentation_result` (2d and 3d segmentation) or `tracking_result` (tracking) argument.
 
 ### Known limitations
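
The two bullets on precomputed embeddings and the smaller models translate into very little user-facing code. Below is a minimal sketch of launching the 2d annotator with cached embeddings and a smaller backbone; the image path, the use of imageio for loading, and the embedding file name are assumptions for illustration, while `embedding_path` and `model_type` are the arguments documented or introduced in this PR.

```python
# Hedged usage sketch: file paths and the imageio loader are assumptions,
# the keyword arguments are the ones exposed by this PR.
import imageio.v3 as imageio

from micro_sam.sam_annotator.annotator_2d import annotator_2d

# load the 2d image that should be annotated (example path)
raw = imageio.imread("example_image.tif")

# embedding_path caches the SegmentAnything image embeddings on disc so that
# restarting the annotator is fast; model_type selects the SAM backbone
# (vit_h is the default, vit_l / vit_b are smaller and faster but may be less accurate)
annotator_2d(
    raw,
    embedding_path="embeddings/example_image.zarr",
    model_type="vit_b",
)
```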

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 7 additions & 3 deletions
@@ -36,11 +36,11 @@ def autosegment_widget(v: Viewer, method: str = "default"):
     v.layers["auto_segmentation"].refresh()
 
 
-def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
+def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None, model_type="vit_h"):
     # for access to the predictor and the image embeddings in the widgets
     global PREDICTOR, IMAGE_EMBEDDINGS, SAM
 
-    PREDICTOR, SAM = util.get_sam_model(return_sam=True)
+    PREDICTOR, SAM = util.get_sam_model(model_type=model_type, return_sam=True)
     IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path, ndim=2)
     util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
 

@@ -166,6 +166,9 @@ def main():
         "--show_embeddings", action="store_true",
         help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
     )
+    parser.add_argument(
+        "--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
+    )
 
     args = parser.parse_args()
     raw = util.load_image_data(args.input, ndim=2, key=args.key)

@@ -180,5 +183,6 @@ def main():
 
     annotator_2d(
         raw, embedding_path=args.embedding_path,
-        show_embeddings=args.show_embeddings, segmentation_result=segmentation_result
+        show_embeddings=args.show_embeddings, segmentation_result=segmentation_result,
+        model_type=args.model_type,
     )
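
The change above simply threads the new keyword from the CLI down to `util.get_sam_model`. For reference, a sketch of that call chain outside of napari, e.g. for scripted use; the dummy test image is a placeholder, and only `get_sam_model`, `precompute_image_embeddings`, and `set_precomputed` are taken from the code in this diff.

```python
# Sketch of the updated call chain; the dummy image is a placeholder.
import numpy as np

from micro_sam import util

raw = np.random.randint(0, 255, (512, 512), dtype="uint8")  # stand-in for real data

# model_type is now forwarded instead of always using the vit_h default
predictor, sam = util.get_sam_model(model_type="vit_l", return_sam=True)

# compute the image embeddings (optionally cached via save_path),
# then attach them to the predictor for interactive prompting
image_embeddings = util.precompute_image_embeddings(predictor, raw, save_path=None, ndim=2)
util.set_precomputed(predictor, image_embeddings)
```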

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 7 additions & 3 deletions
@@ -162,10 +162,10 @@ def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, projection: str
     v.layers["current_object"].refresh()
 
 
-def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
+def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None, model_type="vit_h"):
     # for access to the predictor and the image embeddings in the widgets
     global PREDICTOR, IMAGE_EMBEDDINGS, DEFAULT_PROJECTION
-    PREDICTOR = util.get_sam_model()
+    PREDICTOR = util.get_sam_model(model_type=model_type)
     IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
 
     # the mask projection currently only works for square images

@@ -291,6 +291,9 @@ def main():
         "--show_embeddings", action="store_true",
         help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
     )
+    parser.add_argument(
+        "--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
+    )
 
     args = parser.parse_args()
     raw = util.load_image_data(args.input, ndim=3, key=args.key)

@@ -305,5 +308,6 @@ def main():
 
     annotator_3d(
         raw, embedding_path=args.embedding_path,
-        show_embeddings=args.show_embeddings, segmentation_result=segmentation_result
+        show_embeddings=args.show_embeddings, segmentation_result=segmentation_result,
+        model_type=args.model_type,
     )
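
The 3d annotator gets the same treatment. A sketch of launching it on a volume follows; the hdf5 file name and dataset key are assumptions, while `util.load_image_data` with `ndim=3` and `key` is the loader used by `main()` above.

```python
# Hedged example: file name and dataset key are assumptions.
from micro_sam import util
from micro_sam.sam_annotator.annotator_3d import annotator_3d

# load a volume; key addresses the dataset inside the container file
raw = util.load_image_data("example_volume.h5", ndim=3, key="raw")

# caching the embeddings via embedding_path is especially useful for 3d data,
# where the embedding step can take a while on a CPU (see the README note above)
annotator_3d(raw, embedding_path="embeddings/example_volume.zarr", model_type="vit_b")
```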

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 7 additions & 3 deletions
@@ -299,12 +299,12 @@ def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
     v.layers["prompts"].refresh()
 
 
-def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking_result=None):
+def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking_result=None, model_type="vit_h"):
     # global state
     global PREDICTOR, IMAGE_EMBEDDINGS, CURRENT_TRACK_ID, LINEAGE
     global TRACKING_WIDGET
 
-    PREDICTOR = util.get_sam_model()
+    PREDICTOR = util.get_sam_model(model_type=model_type)
     IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
 
     CURRENT_TRACK_ID = 1

@@ -445,6 +445,9 @@ def main():
         "--show_embeddings", action="store_true",
         help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
     )
+    parser.add_argument(
+        "--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
+    )
 
     args = parser.parse_args()
     raw = util.load_image_data(args.input, ndim=3, key=args.key)

@@ -458,5 +461,6 @@ def main():
         warnings.warn("You have not passed an embedding_path. Restarting the annotator may take a long time.")
 
     annotator_tracking(
-        raw, embedding_path=args.embedding_path, show_embeddings=args.show_embeddings, tracking_result=tracking_result
+        raw, embedding_path=args.embedding_path, show_embeddings=args.show_embeddings,
+        tracking_result=tracking_result, model_type=args.model_type,
     )
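
The same pattern applies to the tracking annotator. A sketch for a 2d timeseries follows; the file name and key are again assumptions, and passing `embedding_path` avoids the slow-restart warning visible in the diff above.

```python
# Hedged example: file name and dataset key are assumptions.
from micro_sam import util
from micro_sam.sam_annotator.annotator_tracking import annotator_tracking

# a 2d timeseries has three axes, hence ndim=3 as in main() above
timeseries = util.load_image_data("example_timeseries.h5", ndim=3, key="raw")

annotator_tracking(
    timeseries,
    embedding_path="embeddings/example_timeseries.zarr",  # cached embeddings, reused across restarts
    model_type="vit_l",
)
```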

micro_sam/util.py

Lines changed: 22 additions & 2 deletions
@@ -1,3 +1,4 @@
+import hashlib
 import os
 from shutil import copyfileobj
 
@@ -28,9 +29,14 @@
     "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
 }
 CHECKPOINT_FOLDER = os.environ.get("SAM_MODELS", os.path.expanduser("~/.sam_models"))
+CHECKSUMS = {
+    "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
+    "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
+    "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"
+}
 
 
-def _download(url, path):
+def _download(url, path, model_type):
     with requests.get(url, stream=True, verify=True) as r:
         if r.status_code != 200:
             r.raise_for_status()
@@ -42,6 +48,20 @@ def _download(url, path):
         with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
             copyfileobj(r_raw, f)
 
+    # validate the checksum
+    expected_checksum = CHECKSUMS[model_type]
+    if expected_checksum is None:
+        return
+    with open(path, "rb") as f:
+        file_ = f.read()
+        checksum = hashlib.sha256(file_).hexdigest()
+    if checksum != expected_checksum:
+        raise RuntimeError(
+            "The checksum of the download does not match the expected checksum."
+            f"Expected: {expected_checksum}, got: {checksum}"
+        )
+    print("Download successful and checksums agree.")
+
 
 def _get_checkpoint(model_type, checkpoint_path=None):
     if checkpoint_path is None:
@@ -52,7 +72,7 @@ def _get_checkpoint(model_type, checkpoint_path=None):
         # download the checkpoint if necessary
         if not os.path.exists(checkpoint_path):
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
-            _download(checkpoint_url, checkpoint_path)
+            _download(checkpoint_url, checkpoint_path, model_type)
     elif not os.path.exists(checkpoint_path):
         raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")

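To make the new download validation concrete, here is a small sketch that checks an already downloaded checkpoint against the `CHECKSUMS` table added above. The `verify_checkpoint` helper and the way the file name is derived from the download URL are illustrations, not part of the PR.

```python
import hashlib
import os

from micro_sam.util import CHECKPOINT_FOLDER, CHECKSUMS


def verify_checkpoint(model_type, checkpoint_name):
    """Return True if the local checkpoint matches the expected sha256.

    Hypothetical helper that mirrors the validation added to _download above.
    """
    path = os.path.join(CHECKPOINT_FOLDER, checkpoint_name)
    with open(path, "rb") as f:
        checksum = hashlib.sha256(f.read()).hexdigest()
    return checksum == CHECKSUMS[model_type]


# the checkpoint name corresponds to the file name in the download URL, e.g.
# verify_checkpoint("vit_b", "sam_vit_b_01ec64.pth")
```

In normal use nothing needs to be called directly: the download path through `_get_checkpoint` and `_download` shown above validates the checksum automatically after each download.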