Skip to content

Commit 02b9d10

Browse files
Merge pull request #157 from computational-cell-analytics/fix-amg-state-serialization
Serialize amg state to the cpu so that it can be loaded without a GPU
2 parents 884efb2 + a7eca9b commit 02b9d10

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

micro_sam/instance_segmentation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def get_state(self) -> Dict[str, Any]:
288288
"""
289289
if not self.is_initialized:
290290
raise RuntimeError("The state has not been computed yet. Call initialize first.")
291+
291292
return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
292293

293294
def set_state(self, state: Dict[str, Any]) -> None:

micro_sam/precompute_state.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Optional, Tuple, Union
1010

1111
import numpy as np
12+
import torch
1213
from segment_anything.predictor import SamPredictor
1314
from tqdm import tqdm
1415

@@ -50,9 +51,21 @@ def cache_amg_state(
5051

5152
if verbose:
5253
print("Precomputing the state for instance segmentation.")
54+
5355
amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
56+
amg_state = amg.get_state()
57+
58+
# put all state onto the cpu so that the state can be deserialized without a gpu
59+
new_crop_list = []
60+
for mask_data in amg_state["crop_list"]:
61+
for k, v in mask_data.items():
62+
if torch.is_tensor(v):
63+
mask_data[k] = v.cpu()
64+
new_crop_list.append(mask_data)
65+
amg_state["crop_list"] = new_crop_list
66+
5467
with open(save_path_amg, "wb") as f:
55-
pickle.dump(amg.get_state(), f)
68+
pickle.dump(amg_state, f)
5669

5770
return amg
5871

0 commit comments

Comments
 (0)