From f2fa21870595c9d5340826d065cc7959e972bb7b Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 11 Oct 2023 00:27:57 +0200 Subject: [PATCH 01/21] WIP Add bioimage.io model creation --- micro_sam/model_zoo.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 micro_sam/model_zoo.py diff --git a/micro_sam/model_zoo.py b/micro_sam/model_zoo.py new file mode 100644 index 000000000..23b9e735e --- /dev/null +++ b/micro_sam/model_zoo.py @@ -0,0 +1,39 @@ +import os +from glob import glob + +from bioimageio.core.build_spec import build_model + + +def _get_livecell_path(input_dir): + test_img_paths = glob(os.path.join(input_dir, "images", "livecell_test_images", "*")) + return test_img_paths[0] + + +def _get_modelzoo_yaml(): + input_path = _get_livecell_path("/scratch/usr/nimanwai/data/livecell") + + build_model( + weight_uri="~/.sam_models/vit_t_mobile_sam.pth", + test_inputs=[input_path], + test_outputs=["./results/"], + input_axes=["bcyx"], + output_axes=["bcyx"], + name="dinosaur", + description="lorem ipsum", + authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, + {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], + tags=["instance segmentation", "segment anything"], + license="LOREM IPSUM", # FIXME + documentation="README.md", # TODO - check out what to put here + cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", + "doi": "10.1101/2023.08.21.554208"}], + output_path="my_micro_sam.zip" + ) + + +def main(): + _get_modelzoo_yaml() + + +if __name__ == "__main__": + main() From 2073119444485cc6e393063521ef7e914e6d47a9 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 11 Oct 2023 09:54:16 +0200 Subject: [PATCH 02/21] Update model building script --- micro_sam/model_zoo.py | 44 +++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/micro_sam/model_zoo.py b/micro_sam/model_zoo.py index 23b9e735e..600a075d9 100644 --- a/micro_sam/model_zoo.py +++ b/micro_sam/model_zoo.py @@ -1,33 +1,51 @@ import os +import numpy as np from glob import glob +import imageio.v2 as imageio from bioimageio.core.build_spec import build_model -def _get_livecell_path(input_dir): - test_img_paths = glob(os.path.join(input_dir, "images", "livecell_test_images", "*")) - return test_img_paths[0] +def _get_livecell_npy_path(input_dir): + test_img_paths = sorted(glob(os.path.join(input_dir, "images", "livecell_test_images", "*"))) + input_image = imageio.imread(test_img_paths[0]) + save_image_path = "./test-livecell-image.npy" + np.save(save_image_path, input_image) + + # TODO: probably we need the prompt inputs here as well + + # TODO: get output paths + # outputs: model(inputs) -> outputs: converted to numpy + save_output_path = ".npy" + + return [save_image_path], [save_output_path] + + +def _get_documentation(doc_path): + with open(doc_path, "w") as f: + f.write("# Segment Anything for Microscopy\n") + f.write("Lorem Ipsum\n") + return doc_path def _get_modelzoo_yaml(): - input_path = _get_livecell_path("/scratch/usr/nimanwai/data/livecell") + input_list, output_list = _get_livecell_npy_path("/scratch/usr/nimanwai/data/livecell") build_model( weight_uri="~/.sam_models/vit_t_mobile_sam.pth", - test_inputs=[input_path], - test_outputs=["./results/"], + test_inputs=input_list, # type: ignore + test_outputs=output_list, # type: ignore input_axes=["bcyx"], output_axes=["bcyx"], name="dinosaur", - description="lorem ipsum", + description="Finetuned Segment Anything models for 
Microscopy", authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], - tags=["instance segmentation", "segment anything"], - license="LOREM IPSUM", # FIXME - documentation="README.md", # TODO - check out what to put here - cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", - "doi": "10.1101/2023.08.21.554208"}], - output_path="my_micro_sam.zip" + tags=["instance-segmentation", "segment-anything"], + license="CC-BY-4.0", # TODO: check with Constantin + documentation=_get_documentation("./doc.md"), + cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], + output_path="./modelzoo/my_micro_sam.zip" ) From fbfde8edc83708967ad52f480c42925db881f552 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 11 Oct 2023 14:29:21 +0200 Subject: [PATCH 03/21] Update model predictor adaptor for bioimage models --- examples/model_zoo/get_bioimage_modelzoo.py | 18 ++++ micro_sam/model_zoo.py | 101 ++++++++++++++++---- micro_sam/predictor_adaptor.py | 38 ++++++++ 3 files changed, 138 insertions(+), 19 deletions(-) create mode 100644 examples/model_zoo/get_bioimage_modelzoo.py create mode 100644 micro_sam/predictor_adaptor.py diff --git a/examples/model_zoo/get_bioimage_modelzoo.py b/examples/model_zoo/get_bioimage_modelzoo.py new file mode 100644 index 000000000..bcf91e1d2 --- /dev/null +++ b/examples/model_zoo/get_bioimage_modelzoo.py @@ -0,0 +1,18 @@ +from micro_sam import model_zoo + + +def main(): + parser = model_zoo._get_modelzoo_parser() + args = parser.parse_args() + + model_zoo.get_modelzoo_yaml( + image_path=args.input_path, + box_prompts=None, + model_type=args.model_type, + output_path=args.output_path, + doc_path=args.doc_path + ) + + +if __name__ == "__main__": + main() diff --git a/micro_sam/model_zoo.py b/micro_sam/model_zoo.py index 600a075d9..2287575b9 100644 --- a/micro_sam/model_zoo.py +++ b/micro_sam/model_zoo.py @@ -1,24 +1,63 @@ import os +import argparse import numpy as np from glob import glob +from typing import List + import imageio.v2 as imageio +import torch + +from micro_sam import util + +from .predictor_adaptor import PredictorAdaptor +from .prompt_based_segmentation import _compute_box_from_mask + from bioimageio.core.build_spec import build_model -def _get_livecell_npy_path(input_dir): +def _get_model(image, model_type): + "Returns the model and predictor while initializing with the model checkpoints" + predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True) # type: ignore + image_embeddings = util.precompute_image_embeddings(predictor, image) + util.set_precomputed(predictor, image_embeddings) + return predictor, sam_model + + +def _get_livecell_npy_paths( + input_dir: str, + model_type: str +): test_img_paths = sorted(glob(os.path.join(input_dir, "images", "livecell_test_images", "*"))) - input_image = imageio.imread(test_img_paths[0]) + chosen_input = test_img_paths[0] + + input_image = imageio.imread(chosen_input) + + fname = os.path.split(chosen_input)[-1] + cell_type = fname.split("_")[0] + label_image = imageio.imread(os.path.join(input_dir, "annotations", "livecell_test_images", cell_type, fname)) + save_image_path = "./test-livecell-image.npy" np.save(save_image_path, input_image) - # TODO: probably we need the prompt inputs here as well + predictor, sam_model = _get_model(input_image, model_type) + get_instance_segmentation = PredictorAdaptor(sam_model=sam_model) - # TODO: get output paths 
- # outputs: model(inputs) -> outputs: converted to numpy - save_output_path = ".npy" + box_prompts = _compute_box_from_mask(label_image) + save_box_prompt_path = "./test-box-prompts.npy" + np.save(save_box_prompt_path, box_prompts) - return [save_image_path], [save_output_path] + instances = get_instance_segmentation( + input_image=torch.from_numpy(input_image)[None, None], + predictor=predictor, + image_embeddings=None, + box_prompts=torch.from_numpy(box_prompts)[None] + ) + + save_output_path = "./test-livecell-output.npy" + np.save(save_output_path, instances.squeeze().numpy()) + + return [save_image_path, save_box_prompt_path], [save_output_path] def _get_documentation(doc_path): @@ -28,11 +67,30 @@ def _get_documentation(doc_path): return doc_path -def _get_modelzoo_yaml(): - input_list, output_list = _get_livecell_npy_path("/scratch/usr/nimanwai/data/livecell") +def _get_sam_checkpoints(model_type): + checkpoint = util._get_checkpoint(model_type, None) + print(f"{model_type} is available at {checkpoint}") + return checkpoint + + +def get_modelzoo_yaml( + image_path: str, + box_prompts: List[int], + model_type: str, + output_path: str, + doc_path: str +): + # load the model and the image and prompts + # feed prompts and image to the model to get the output + # save the numpy file for the output to get the expected data + + input_list, output_list = _get_livecell_npy_paths(input_dir=image_path, model_type=model_type) + _checkpoint = _get_sam_checkpoints(model_type) + + breakpoint() build_model( - weight_uri="~/.sam_models/vit_t_mobile_sam.pth", + weight_uri=_checkpoint, test_inputs=input_list, # type: ignore test_outputs=output_list, # type: ignore input_axes=["bcyx"], @@ -42,16 +100,21 @@ def _get_modelzoo_yaml(): authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], tags=["instance-segmentation", "segment-anything"], - license="CC-BY-4.0", # TODO: check with Constantin - documentation=_get_documentation("./doc.md"), + license="CC-BY-4.0", + documentation=_get_documentation(doc_path), cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], - output_path="./modelzoo/my_micro_sam.zip" + output_path=output_path ) -def main(): - _get_modelzoo_yaml() - - -if __name__ == "__main__": - main() +def _get_modelzoo_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_path", type=str, + help="Path to the raw inputs' directory") + parser.add_argument("-m", "--model_type", type=str, default="vit_b", + help="Name of the model to get the SAM checkpoints") + parser.add_argument("-o", "--output_path", type=str, default="./models/sam.zip", + help="Path to the output bioimage modelzoo-format SAM model") + parser.add_argument("-d", "--doc_path", type=str, default="./documentation.md", + help="Path to the documentation") + return parser diff --git a/micro_sam/predictor_adaptor.py b/micro_sam/predictor_adaptor.py new file mode 100644 index 000000000..c69313698 --- /dev/null +++ b/micro_sam/predictor_adaptor.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch + +from segment_anything.predictor import SamPredictor + + +class PredictorAdaptor(SamPredictor): + """Wrapper around the SamPredictor to be used by BioImage.IO model format. + + This model supports the same functionality as SamPredictor and can provide mask segmentations + from box, point or mask input prompts. 
+    """
+    def __call__(
+        self,
+        input_image: torch.Tensor,
+        image_embeddings: Optional[torch.Tensor] = None,
+        box_prompts: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if self.is_image_set and image_embeddings is None:  # we have embeddings set and not passed
+            pass  # do nothing
+        elif self.is_image_set and image_embeddings is not None:
+            raise NotImplementedError  # TODO: replace the image embeddings
+        elif image_embeddings is not None:
+            pass  # TODO set the image embeddings
+            # self.features = image_embeddings
+        elif not self.is_image_set:
+            self.set_torch_image(input_image)  # compute the image embeddings
+
+        instance_segmentation, _, _ = self.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=box_prompts,
+            multimask_output=False
+        )
+        # TODO get the image embeddings via image_embeddings = self.features
+        # and return them
+        return instance_segmentation

From c62fe939276e7a9d34b528e0d39f95bbaff47663 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Wed, 11 Oct 2023 14:31:34 +0200
Subject: [PATCH 04/21] Refactor modelzoo functionality into submodule

---
 micro_sam/modelzoo/__init__.py                       | 0
 micro_sam/{model_zoo.py => modelzoo/model_export.py} | 0
 micro_sam/{ => modelzoo}/predictor_adaptor.py        | 0
 3 files changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 micro_sam/modelzoo/__init__.py
 rename micro_sam/{model_zoo.py => modelzoo/model_export.py} (100%)
 rename micro_sam/{ => modelzoo}/predictor_adaptor.py (100%)

diff --git a/micro_sam/modelzoo/__init__.py b/micro_sam/modelzoo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/micro_sam/model_zoo.py b/micro_sam/modelzoo/model_export.py
similarity index 100%
rename from micro_sam/model_zoo.py
rename to micro_sam/modelzoo/model_export.py
diff --git a/micro_sam/predictor_adaptor.py b/micro_sam/modelzoo/predictor_adaptor.py
similarity index 100%
rename from micro_sam/predictor_adaptor.py
rename to micro_sam/modelzoo/predictor_adaptor.py

From 615ffc64db9b8b163978c6027757109d34890158 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Wed, 11 Oct 2023 15:53:38 +0200
Subject: [PATCH 05/21] Add first working scripts for bioengine export

---
 examples/model_zoo/bioengine_export.py     |   7 ++
 micro_sam/modelzoo/image_encoder_export.py |  54 ++++++++++
 micro_sam/modelzoo/onnx_export.py          | 119 +++++++++++++++++++++
 3 files changed, 180 insertions(+)
 create mode 100644 examples/model_zoo/bioengine_export.py
 create mode 100644 micro_sam/modelzoo/image_encoder_export.py
 create mode 100644 micro_sam/modelzoo/onnx_export.py

diff --git a/examples/model_zoo/bioengine_export.py b/examples/model_zoo/bioengine_export.py
new file mode 100644
index 000000000..4d70ff508
--- /dev/null
+++ b/examples/model_zoo/bioengine_export.py
@@ -0,0 +1,7 @@
+# TODO combined script
+from micro_sam.modelzoo.image_encoder_export import export_image_encoder
+from micro_sam.modelzoo.onnx_export import export_onnx
+
+model_type = "vit_b"
+export_image_encoder(model_type, "./test-export")
+export_onnx(model_type, "./test-export", opset=12)
diff --git a/micro_sam/modelzoo/image_encoder_export.py b/micro_sam/modelzoo/image_encoder_export.py
new file mode 100644
index 000000000..b001e0a24
--- /dev/null
+++ b/micro_sam/modelzoo/image_encoder_export.py
@@ -0,0 +1,54 @@
+import os
+import torch
+from ..util import get_sam_model
+
+
+ENCODER_CONFIG = """name: "sam-backbone"
+backend: "pytorch"
+platform: "pytorch_libtorch"
+
+max_batch_size : 1
+input [
+  {
+    name: "input0__0"
+    data_type: TYPE_FP32
+    dims: [3, -1, -1]
+  }
+]
+output [
+  {
+    name: 
"output0__0" + data_type: TYPE_FP32 + dims: [256, 64, 64] + } +] + +parameters: { + key: "INFERENCE_MODE" + value: { + string_value: "true" + } +}""" + + +def export_image_encoder( + model_type, + output_root, + checkpoint_path=None, +): + output_folder = os.path.join(output_root, "sam-backbone") + weight_output_folder = os.path.join(output_folder, "1") + os.makedirs(weight_output_folder, exist_ok=True) + + predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + encoder = predictor.model.image_encoder + + encoder.eval() + input_ = torch.rand(1, 3, 1024, 1024) + traced_model = torch.jit.trace(encoder, input_) + weight_path = os.path.join(weight_output_folder, "model.pt") + traced_model.save(weight_path) + + config_output_path = os.path.join(output_folder, "config.pbtxt") + with open(config_output_path, "w") as f: + f.write(ENCODER_CONFIG) diff --git a/micro_sam/modelzoo/onnx_export.py b/micro_sam/modelzoo/onnx_export.py new file mode 100644 index 000000000..5a8632e2f --- /dev/null +++ b/micro_sam/modelzoo/onnx_export.py @@ -0,0 +1,119 @@ +import os +import warnings + +import torch +from segment_anything.utils.onnx import SamOnnxModel + +try: + import onnxruntime + onnxruntime_exists = True +except ImportError: + onnxruntime_exists = False + +from ..util import get_sam_model + + +# TODO check if this is still correct +DECODER_CONFIG = """name: "sam-decoder" +backend: "onnxruntime" +platform: "onnxruntime_onnx" + +parameters: { + key: "INFERENCE_MODE" + value: { + string_value: "true" + } +} + +instance_group { + count: 1 + kind: KIND_CPU +}""" + + +def to_numpy(tensor): + return tensor.cpu().numpy() + + +# ONNX export script adapted from +# https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py +def export_onnx( + model_type, + output_root, + opset: int, + checkpoint_path=None, + return_single_mask: bool = True, + gelu_approximate: bool = False, + use_stability_score: bool = False, + return_extra_metrics: bool = False, +): + output_folder = os.path.join(output_root, "sam-decoder") + weight_output_folder = os.path.join(output_folder, "1") + os.makedirs(weight_output_folder, exist_ok=True) + + _, sam = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True) + weight_path = os.path.join(weight_output_folder, "model.onnx") + + onnx_model = SamOnnxModel( + model=sam, + return_single_mask=return_single_mask, + use_stability_score=use_stability_score, + return_extra_metrics=return_extra_metrics, + ) + + if gelu_approximate: + for n, m in onnx_model.named_modules: + if isinstance(m, torch.nn.GELU): + m.approximate = "tanh" + + dynamic_axes = { + "point_coords": {1: "num_points"}, + "point_labels": {1: "num_points"}, + } + + embed_dim = sam.prompt_encoder.embed_dim + embed_size = sam.prompt_encoder.image_embedding_size + + mask_input_size = [4 * x for x in embed_size] + dummy_inputs = { + "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), + "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), + "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), + "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), + "has_mask_input": torch.tensor([1], dtype=torch.float), + "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), + } + + _ = onnx_model(**dummy_inputs) + + output_names = ["masks", "iou_predictions", "low_res_masks"] + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", 
category=torch.jit.TracerWarning)
+        warnings.filterwarnings("ignore", category=UserWarning)
+        with open(weight_path, "wb") as f:
+            print(f"Exporting onnx model to {weight_path}...")
+            torch.onnx.export(
+                onnx_model,
+                tuple(dummy_inputs.values()),
+                f,
+                export_params=True,
+                verbose=False,
+                opset_version=opset,
+                do_constant_folding=True,
+                input_names=list(dummy_inputs.keys()),
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+            )
+
+    if onnxruntime_exists:
+        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
+        # set cpu provider default
+        providers = ["CPUExecutionProvider"]
+        ort_session = onnxruntime.InferenceSession(weight_path, providers=providers)
+        _ = ort_session.run(None, ort_inputs)
+        print("Model has successfully been run with ONNXRuntime.")
+
+    config_output_path = os.path.join(output_folder, "config.pbtxt")
+    with open(config_output_path, "w") as f:
+        f.write(DECODER_CONFIG)

From 09f5f76d74b037a13115bef2dc6e708686c60643 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 11 Oct 2023 16:39:44 +0200
Subject: [PATCH 06/21] Add input prompt transforms to adaptor

---
 examples/model_zoo/get_bioimage_modelzoo.py |  8 +++----
 micro_sam/modelzoo/model_export.py          | 24 ++++++++++-----------
 micro_sam/modelzoo/predictor_adaptor.py     | 15 +++++++++++--
 3 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/examples/model_zoo/get_bioimage_modelzoo.py b/examples/model_zoo/get_bioimage_modelzoo.py
index bcf91e1d2..e54317b9b 100644
--- a/examples/model_zoo/get_bioimage_modelzoo.py
+++ b/examples/model_zoo/get_bioimage_modelzoo.py
@@ -1,13 +1,13 @@
-from micro_sam import model_zoo
+from micro_sam.modelzoo import model_export


 def main():
-    parser = model_zoo._get_modelzoo_parser()
+    parser = model_export._get_modelzoo_parser()
     args = parser.parse_args()

-    model_zoo.get_modelzoo_yaml(
+    model_export.get_modelzoo_yaml(
         image_path=args.input_path,
-        box_prompts=None,
+        box_prompts_path=args.boxes_path,
         model_type=args.model_type,
         output_path=args.output_path,
         doc_path=args.doc_path
diff --git a/micro_sam/modelzoo/model_export.py b/micro_sam/modelzoo/model_export.py
index 2287575b9..7447724f9 100644
--- a/micro_sam/modelzoo/model_export.py
+++ b/micro_sam/modelzoo/model_export.py
@@ -2,7 +2,6 @@
 import argparse
 import numpy as np
 from glob import glob
-from typing import List

 import imageio.v2 as imageio

@@ -11,7 +10,7 @@
 from micro_sam import util

 from .predictor_adaptor import PredictorAdaptor
-from .prompt_based_segmentation import _compute_box_from_mask
+from ..prompt_based_segmentation import _compute_box_from_mask

 from bioimageio.core.build_spec import build_model

@@ -40,16 +39,17 @@ def _get_livecell_npy_paths(
     save_image_path = "./test-livecell-image.npy"
     np.save(save_image_path, input_image)

-    predictor, sam_model = _get_model(input_image, model_type)
+    _, sam_model = _get_model(input_image, model_type)
     get_instance_segmentation = PredictorAdaptor(sam_model=sam_model)

     box_prompts = _compute_box_from_mask(label_image)
     save_box_prompt_path = "./test-box-prompts.npy"
     np.save(save_box_prompt_path, box_prompts)

+    input_image = util._to_image(input_image).transpose(2, 0, 1)
+
     instances = get_instance_segmentation(
-        input_image=torch.from_numpy(input_image)[None, None],
-        predictor=predictor,
+        input_image=torch.from_numpy(input_image)[None],
         image_embeddings=None,
         box_prompts=torch.from_numpy(box_prompts)[None]
     )

     save_output_path = "./test-livecell-output.npy"
     np.save(save_output_path, instances.squeeze().numpy())

-    return 
[save_image_path, save_box_prompt_path], [save_output_path] + return save_image_path, save_output_path def _get_documentation(doc_path): @@ -75,18 +75,16 @@ def _get_sam_checkpoints(model_type): def get_modelzoo_yaml( image_path: str, - box_prompts: List[int], + box_prompts_path: str, model_type: str, output_path: str, doc_path: str ): - # load the model and the image and prompts - # feed prompts and image to the model to get the output - # save the numpy file for the output to get the expected data - - input_list, output_list = _get_livecell_npy_paths(input_dir=image_path, model_type=model_type) + input_path, output_path = _get_livecell_npy_paths(input_dir=image_path, model_type=model_type) _checkpoint = _get_sam_checkpoints(model_type) + input_list = [input_path, box_prompts_path] + output_list = [output_path] breakpoint() build_model( @@ -117,4 +115,6 @@ def _get_modelzoo_parser(): help="Path to the output bioimage modelzoo-format SAM model") parser.add_argument("-d", "--doc_path", type=str, default="./documentation.md", help="Path to the documentation") + parser.add_argument("--boxes_path", type=str, default=None, + help="Path to the saved box prompts") return parser diff --git a/micro_sam/modelzoo/predictor_adaptor.py b/micro_sam/modelzoo/predictor_adaptor.py index c69313698..8fc4a058a 100644 --- a/micro_sam/modelzoo/predictor_adaptor.py +++ b/micro_sam/modelzoo/predictor_adaptor.py @@ -17,6 +17,11 @@ def __call__( image_embeddings: Optional[torch.Tensor] = None, box_prompts: Optional[torch.Tensor] = None ) -> torch.Tensor: + """Expected inputs: + - input_image: torch inputs of dimensions B x C x H x W + - image_embeddings: precomputed image embeddings + - box_prompts: box prompts of dimensions C x 4 + """ if self.is_image_set and image_embeddings is None: # we have embeddings set and not passed pass # do nothing elif self.is_image_set and image_embeddings is not None: @@ -25,14 +30,20 @@ def __call__( pass # TODO set the image embeddings # self.features = image_embeddings elif not self.is_image_set: - self.set_torch_image(input_image) # compute the image embeddings + image = self.transform.apply_image_torch(input_image) + self.set_torch_image(image, original_image_size=input_image.numpy().shape[2:]) # compute the image embeddings + + boxes = self.transform.apply_boxes_torch(box_prompts, original_size=input_image.numpy().shape[2:]) # type: ignore instance_segmentation, _, _ = self.predict_torch( point_coords=None, point_labels=None, - boxes=box_prompts, + boxes=boxes, multimask_output=False ) + + assert instance_segmentation.shape[2:] == input_image.shape[2:], f"{instance_segmentation.shape[2:]} is not as expected ({input_image.shape[2:]})" + # TODO get the image embeddings via image_embeddings = self.features # and return them return instance_segmentation From b869a33e43767a8d6d127d63974b183346d04190 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 11 Oct 2023 18:09:55 +0200 Subject: [PATCH 07/21] Update numpy input saving --- micro_sam/modelzoo/model_export.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/micro_sam/modelzoo/model_export.py b/micro_sam/modelzoo/model_export.py index 7447724f9..088b81126 100644 --- a/micro_sam/modelzoo/model_export.py +++ b/micro_sam/modelzoo/model_export.py @@ -37,7 +37,7 @@ def _get_livecell_npy_paths( label_image = imageio.imread(os.path.join(input_dir, "annotations", "livecell_test_images", cell_type, fname)) save_image_path = "./test-livecell-image.npy" - np.save(save_image_path, input_image) + 
np.save(save_image_path, input_image[None, None]) _, sam_model = _get_model(input_image, model_type) get_instance_segmentation = PredictorAdaptor(sam_model=sam_model) @@ -55,7 +55,7 @@ def _get_livecell_npy_paths( ) save_output_path = "./test-livecell-output.npy" - np.save(save_output_path, instances.squeeze().numpy()) + np.save(save_output_path, instances.numpy()) return save_image_path, save_output_path @@ -63,7 +63,8 @@ def _get_livecell_npy_paths( def _get_documentation(doc_path): with open(doc_path, "w") as f: f.write("# Segment Anything for Microscopy\n") - f.write("Lorem Ipsum\n") + f.write("We extend Segment Anything, a vision foundation model for image segmentation ") + f.write("by training specialized models for microscopy data.\n") return doc_path @@ -85,7 +86,8 @@ def get_modelzoo_yaml( input_list = [input_path, box_prompts_path] output_list = [output_path] - breakpoint() + + architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") build_model( weight_uri=_checkpoint, @@ -101,7 +103,8 @@ def get_modelzoo_yaml( license="CC-BY-4.0", documentation=_get_documentation(doc_path), cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], - output_path=output_path + output_path=output_path, + architecture=architecture_path ) @@ -111,7 +114,7 @@ def _get_modelzoo_parser(): help="Path to the raw inputs' directory") parser.add_argument("-m", "--model_type", type=str, default="vit_b", help="Name of the model to get the SAM checkpoints") - parser.add_argument("-o", "--output_path", type=str, default="./models/sam.zip", + parser.add_argument("-o", "--output_path", type=str, default="./models/micro_sam.zip", help="Path to the output bioimage modelzoo-format SAM model") parser.add_argument("-d", "--doc_path", type=str, default="./documentation.md", help="Path to the documentation") From 96dc3f6f671db97e1d8bd9a0ebb2f9186874dd3c Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 11 Oct 2023 18:12:26 +0200 Subject: [PATCH 08/21] Update bioengine export script --- examples/model_zoo/bioengine_export.py | 8 +- .../{onnx_export.py => bioengine_export.py} | 98 +++++++++++++++++-- micro_sam/modelzoo/image_encoder_export.py | 54 ---------- 3 files changed, 94 insertions(+), 66 deletions(-) rename micro_sam/modelzoo/{onnx_export.py => bioengine_export.py} (58%) delete mode 100644 micro_sam/modelzoo/image_encoder_export.py diff --git a/examples/model_zoo/bioengine_export.py b/examples/model_zoo/bioengine_export.py index 4d70ff508..0b3f9762e 100644 --- a/examples/model_zoo/bioengine_export.py +++ b/examples/model_zoo/bioengine_export.py @@ -1,7 +1,3 @@ -# TODO combined script -from micro_sam.modelzoo.image_encoder_export import export_image_encoder -from micro_sam.modelzoo.onnx_export import export_onnx +from micro_sam.modelzoo.bioengine_export import export_bioengine_model -model_type = "vit_b" -export_image_encoder(model_type, "./test-export") -export_onnx(model_type, "./test-export", opset=12) +export_bioengine_model("vit_b", "test-export", opset=12) diff --git a/micro_sam/modelzoo/onnx_export.py b/micro_sam/modelzoo/bioengine_export.py similarity index 58% rename from micro_sam/modelzoo/onnx_export.py rename to micro_sam/modelzoo/bioengine_export.py index 5a8632e2f..63fbd102a 100644 --- a/micro_sam/modelzoo/onnx_export.py +++ b/micro_sam/modelzoo/bioengine_export.py @@ -1,5 +1,6 @@ import os import warnings +from typing import Optional, Union import torch from segment_anything.utils.onnx import SamOnnxModel @@ 
-13,8 +14,35 @@ from ..util import get_sam_model -# TODO check if this is still correct -DECODER_CONFIG = """name: "sam-decoder" +ENCODER_CONFIG = """name: "%s" +backend: "pytorch" +platform: "pytorch_libtorch" + +max_batch_size : 1 +input [ + { + name: "input0__0" + data_type: TYPE_FP32 + dims: [3, -1, -1] + } +] +output [ + { + name: "output0__0" + data_type: TYPE_FP32 + dims: [256, 64, 64] + } +] + +parameters: { + key: "INFERENCE_MODE" + value: { + string_value: "true" + } +}""" + + +DECODER_CONFIG = """name: "%s" backend: "onnxruntime" platform: "onnxruntime_onnx" @@ -35,19 +63,52 @@ def to_numpy(tensor): return tensor.cpu().numpy() +def export_image_encoder( + model_type: str, + output_root: Union[str, os.PathLike], + checkpoint_path: Optional[str] = None, + export_name: Optional[str] = None, +): + if export_name is None: + export_name = model_type + name = f"sam-{export_name}-encoder" + + output_folder = os.path.join(output_root, name) + weight_output_folder = os.path.join(output_folder, "1") + os.makedirs(weight_output_folder, exist_ok=True) + + predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + encoder = predictor.model.image_encoder + + encoder.eval() + input_ = torch.rand(1, 3, 1024, 1024) + traced_model = torch.jit.trace(encoder, input_) + weight_path = os.path.join(weight_output_folder, "model.pt") + traced_model.save(weight_path) + + config_output_path = os.path.join(output_folder, "config.pbtxt") + with open(config_output_path, "w") as f: + f.write(ENCODER_CONFIG % name) + + # ONNX export script adapted from # https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py -def export_onnx( +def export_onnx_model( model_type, output_root, opset: int, - checkpoint_path=None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_single_mask: bool = True, gelu_approximate: bool = False, use_stability_score: bool = False, return_extra_metrics: bool = False, + export_name: Optional[str] = None, ): - output_folder = os.path.join(output_root, "sam-decoder") + if export_name is None: + export_name = model_type + name = f"sam-{export_name}-decoder" + + output_folder = os.path.join(output_root, name) weight_output_folder = os.path.join(output_folder, "1") os.makedirs(weight_output_folder, exist_ok=True) @@ -116,4 +177,29 @@ def export_onnx( config_output_path = os.path.join(output_folder, "config.pbtxt") with open(config_output_path, "w") as f: - f.write(DECODER_CONFIG) + f.write(DECODER_CONFIG % name) + + +def export_bioengine_model( + model_type, + output_root, + opset: int, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + return_single_mask: bool = True, + gelu_approximate: bool = False, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + export_name: Optional[str] = None, +): + export_image_encoder(model_type, output_root, checkpoint_path, export_name) + export_onnx_model( + model_type=model_type, + output_root=output_root, + opset=opset, + checkpoint_path=checkpoint_path, + return_single_mask=return_single_mask, + gelu_approximate=gelu_approximate, + use_stability_score=use_stability_score, + return_extra_metrics=return_extra_metrics, + export_name=export_name + ) diff --git a/micro_sam/modelzoo/image_encoder_export.py b/micro_sam/modelzoo/image_encoder_export.py deleted file mode 100644 index b001e0a24..000000000 --- a/micro_sam/modelzoo/image_encoder_export.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -import torch -from ..util import get_sam_model - - 
-ENCODER_CONFIG = """name: "sam-backbone" -backend: "pytorch" -platform: "pytorch_libtorch" - -max_batch_size : 1 -input [ - { - name: "input0__0" - data_type: TYPE_FP32 - dims: [3, -1, -1] - } -] -output [ - { - name: "output0__0" - data_type: TYPE_FP32 - dims: [256, 64, 64] - } -] - -parameters: { - key: "INFERENCE_MODE" - value: { - string_value: "true" - } -}""" - - -def export_image_encoder( - model_type, - output_root, - checkpoint_path=None, -): - output_folder = os.path.join(output_root, "sam-backbone") - weight_output_folder = os.path.join(output_folder, "1") - os.makedirs(weight_output_folder, exist_ok=True) - - predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) - encoder = predictor.model.image_encoder - - encoder.eval() - input_ = torch.rand(1, 3, 1024, 1024) - traced_model = torch.jit.trace(encoder, input_) - weight_path = os.path.join(weight_output_folder, "model.pt") - traced_model.save(weight_path) - - config_output_path = os.path.join(output_folder, "config.pbtxt") - with open(config_output_path, "w") as f: - f.write(ENCODER_CONFIG) From 7d6bfcab2a9f7793e457f834c6582533e1c4abc2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 12 Oct 2023 10:42:09 +0200 Subject: [PATCH 09/21] Refactor modelzoo export --- examples/model_zoo/get_bioimage_modelzoo.py | 18 --- .../export_model_for_bioengine.py} | 0 .../modelzoo/export_model_for_bioimageio.py | 20 +++ examples/modelzoo/imjoy_test.py | 42 ++++++ micro_sam/modelzoo/__init__.py | 2 + micro_sam/modelzoo/bioimageio_export.py | 100 ++++++++++++++ micro_sam/modelzoo/model_export.py | 123 ------------------ 7 files changed, 164 insertions(+), 141 deletions(-) delete mode 100644 examples/model_zoo/get_bioimage_modelzoo.py rename examples/{model_zoo/bioengine_export.py => modelzoo/export_model_for_bioengine.py} (100%) create mode 100644 examples/modelzoo/export_model_for_bioimageio.py create mode 100644 examples/modelzoo/imjoy_test.py create mode 100644 micro_sam/modelzoo/bioimageio_export.py delete mode 100644 micro_sam/modelzoo/model_export.py diff --git a/examples/model_zoo/get_bioimage_modelzoo.py b/examples/model_zoo/get_bioimage_modelzoo.py deleted file mode 100644 index e54317b9b..000000000 --- a/examples/model_zoo/get_bioimage_modelzoo.py +++ /dev/null @@ -1,18 +0,0 @@ -from micro_sam.modelzoo import model_export - - -def main(): - parser = model_export._get_modelzoo_parser() - args = parser.parse_args() - - model_export.get_modelzoo_yaml( - image_path=args.input_path, - box_prompts_path=args.boxes_path, - model_type=args.model_type, - output_path=args.output_path, - doc_path=args.doc_path - ) - - -if __name__ == "__main__": - main() diff --git a/examples/model_zoo/bioengine_export.py b/examples/modelzoo/export_model_for_bioengine.py similarity index 100% rename from examples/model_zoo/bioengine_export.py rename to examples/modelzoo/export_model_for_bioengine.py diff --git a/examples/modelzoo/export_model_for_bioimageio.py b/examples/modelzoo/export_model_for_bioimageio.py new file mode 100644 index 000000000..cd4214533 --- /dev/null +++ b/examples/modelzoo/export_model_for_bioimageio.py @@ -0,0 +1,20 @@ +from micro_sam.modelzoo import export_bioimageio_model +from micro_sam.sample_data import synthetic_data + + +def export_model_with_synthetic_data(): + image, labels = synthetic_data(shape=(1024, 1024)) + + export_bioimageio_model( + image, labels, + model_type="vit_b", model_name="sam-test-vit-b", + output_path="./test_export.zip", + ) + + +def main(): + export_model_with_synthetic_data() 
+ + +if __name__ == "__main__": + main() diff --git a/examples/modelzoo/imjoy_test.py b/examples/modelzoo/imjoy_test.py new file mode 100644 index 000000000..77e78b741 --- /dev/null +++ b/examples/modelzoo/imjoy_test.py @@ -0,0 +1,42 @@ +import numpy as np +from imjoy_rpc.hypha import connect_to_server +import time + +image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype( + "float32" +) + +# SERVER_URL = 'http://127.0.0.1:9520' # "https://ai.imjoy.io" +SERVER_URL = "https://hypha.bioimage.io" + + +async def test_backbone(triton): + config = await triton.get_config(model_name="micro-sam-vit-b-backbone") + print(config) + + image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype( + "float32" + ) + + start_time = time.time() + result = await triton.execute( + inputs=[image], + model_name="micro-sam-vit-b-backbone", + ) + print("Backbone", result) + embedding = result['output0__0'] + print("Time taken: ", time.time() - start_time) + print("Test passed", embedding.shape) + + +async def run(): + server = await connect_to_server( + {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100} + ) + triton = await server.get_service("triton-client") + await test_backbone(triton) + + +if __name__ == "__main__": + import asyncio + asyncio.run(run()) diff --git a/micro_sam/modelzoo/__init__.py b/micro_sam/modelzoo/__init__.py index e69de29bb..a66dd2da9 100644 --- a/micro_sam/modelzoo/__init__.py +++ b/micro_sam/modelzoo/__init__.py @@ -0,0 +1,2 @@ +from .bioimageio_export import export_bioimageio_model +from .bioengine_export import export_bioengine_model diff --git a/micro_sam/modelzoo/bioimageio_export.py b/micro_sam/modelzoo/bioimageio_export.py new file mode 100644 index 000000000..d03a366e4 --- /dev/null +++ b/micro_sam/modelzoo/bioimageio_export.py @@ -0,0 +1,100 @@ +import os +import numpy as np +from typing import Optional, Union + +import torch + +from bioimageio.core.build_spec import build_model + +from .. import util +from ..prompt_generators import PointAndBoxPromptGenerator +from .predictor_adaptor import PredictorAdaptor + + +def _get_model(image, model_type, checkpoint_path): + "Returns the model and predictor while initializing with the model checkpoints" + predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True) # type: ignore + image_embeddings = util.precompute_image_embeddings(predictor, image) + util.set_precomputed(predictor, image_embeddings) + return predictor, sam_model + + +# TODO use tempfile +def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path): + + # For now we just generate a single box prompt here, but we could also generate more input prompts. 
+ generator = PointAndBoxPromptGenerator(0, 0, 4, False, True) + centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels) + masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1]) + _, _, box_prompts, _ = generator(masks, [bounding_boxes[1]]) + box_prompts = box_prompts.numpy() + + save_image_path = "./test-livecell-image.npy" + np.save(save_image_path, image[None, None]) + + _, sam_model = _get_model(image, model_type, checkpoint_path) + predictor = PredictorAdaptor(sam_model=sam_model) + + save_box_prompt_path = "./test-box-prompts.npy" + np.save(save_box_prompt_path, box_prompts) + + input_ = util._to_image(image).transpose(2, 0, 1) + + # TODO embeddings are also expected output + instances = predictor( + input_image=torch.from_numpy(input_)[None], + image_embeddings=None, + box_prompts=torch.from_numpy(box_prompts)[None] + ) + + save_output_path = "./test-livecell-output.npy" + np.save(save_output_path, instances.numpy()) + + return [save_image_path, save_box_prompt_path], [save_output_path] + + +def _get_documentation(doc_path): + with open(doc_path, "w") as f: + f.write("# Segment Anything for Microscopy\n") + f.write("We extend Segment Anything, a vision foundation model for image segmentation ") + f.write("by training specialized models for microscopy data.\n") + return doc_path + + +def export_bioimageio_model( + image: np.ndarray, + label_image: np.ndarray, + model_type: str, + model_name: str, + output_path: Union[str, os.PathLike], + doc_path: Optional[Union[str, os.PathLike]] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, +): + input_paths, result_paths = _create_test_inputs_and_outputs( + image, label_image, model_type, checkpoint_path + ) + checkpoint = util._get_checkpoint(model_type, checkpoint_path=checkpoint_path) + + architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") + + if doc_path is None: + doc_path = "./doc.md" + _get_documentation(doc_path) + + build_model( + weight_uri=checkpoint, + test_inputs=input_paths, # type: ignore + test_outputs=result_paths, # type: ignore + input_axes=["bcyx"], + output_axes=["bcyx"], + name=model_name, + description="Finetuned Segment Anything models for Microscopy", + authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, + {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], + tags=["instance-segmentation", "segment-anything"], + license="CC-BY-4.0", + documentation=doc_path, + cite=[{"text": "Archit, ..., Pape et al. 
Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], + output_path=output_path, + architecture=architecture_path + ) diff --git a/micro_sam/modelzoo/model_export.py b/micro_sam/modelzoo/model_export.py deleted file mode 100644 index 088b81126..000000000 --- a/micro_sam/modelzoo/model_export.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import argparse -import numpy as np -from glob import glob - -import imageio.v2 as imageio - -import torch - -from micro_sam import util - -from .predictor_adaptor import PredictorAdaptor -from ..prompt_based_segmentation import _compute_box_from_mask - -from bioimageio.core.build_spec import build_model - - -def _get_model(image, model_type): - "Returns the model and predictor while initializing with the model checkpoints" - predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True) # type: ignore - image_embeddings = util.precompute_image_embeddings(predictor, image) - util.set_precomputed(predictor, image_embeddings) - return predictor, sam_model - - -def _get_livecell_npy_paths( - input_dir: str, - model_type: str -): - test_img_paths = sorted(glob(os.path.join(input_dir, "images", "livecell_test_images", "*"))) - chosen_input = test_img_paths[0] - - input_image = imageio.imread(chosen_input) - - fname = os.path.split(chosen_input)[-1] - cell_type = fname.split("_")[0] - label_image = imageio.imread(os.path.join(input_dir, "annotations", "livecell_test_images", cell_type, fname)) - - save_image_path = "./test-livecell-image.npy" - np.save(save_image_path, input_image[None, None]) - - _, sam_model = _get_model(input_image, model_type) - get_instance_segmentation = PredictorAdaptor(sam_model=sam_model) - - box_prompts = _compute_box_from_mask(label_image) - save_box_prompt_path = "./test-box-prompts.npy" - np.save(save_box_prompt_path, box_prompts) - - input_image = util._to_image(input_image).transpose(2, 0, 1) - - instances = get_instance_segmentation( - input_image=torch.from_numpy(input_image)[None], - image_embeddings=None, - box_prompts=torch.from_numpy(box_prompts)[None] - ) - - save_output_path = "./test-livecell-output.npy" - np.save(save_output_path, instances.numpy()) - - return save_image_path, save_output_path - - -def _get_documentation(doc_path): - with open(doc_path, "w") as f: - f.write("# Segment Anything for Microscopy\n") - f.write("We extend Segment Anything, a vision foundation model for image segmentation ") - f.write("by training specialized models for microscopy data.\n") - return doc_path - - -def _get_sam_checkpoints(model_type): - checkpoint = util._get_checkpoint(model_type, None) - print(f"{model_type} is available at {checkpoint}") - return checkpoint - - -def get_modelzoo_yaml( - image_path: str, - box_prompts_path: str, - model_type: str, - output_path: str, - doc_path: str -): - input_path, output_path = _get_livecell_npy_paths(input_dir=image_path, model_type=model_type) - _checkpoint = _get_sam_checkpoints(model_type) - - input_list = [input_path, box_prompts_path] - output_list = [output_path] - - architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") - - build_model( - weight_uri=_checkpoint, - test_inputs=input_list, # type: ignore - test_outputs=output_list, # type: ignore - input_axes=["bcyx"], - output_axes=["bcyx"], - name="dinosaur", - description="Finetuned Segment Anything models for Microscopy", - authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, - {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], - 
tags=["instance-segmentation", "segment-anything"], - license="CC-BY-4.0", - documentation=_get_documentation(doc_path), - cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], - output_path=output_path, - architecture=architecture_path - ) - - -def _get_modelzoo_parser(): - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input_path", type=str, - help="Path to the raw inputs' directory") - parser.add_argument("-m", "--model_type", type=str, default="vit_b", - help="Name of the model to get the SAM checkpoints") - parser.add_argument("-o", "--output_path", type=str, default="./models/micro_sam.zip", - help="Path to the output bioimage modelzoo-format SAM model") - parser.add_argument("-d", "--doc_path", type=str, default="./documentation.md", - help="Path to the documentation") - parser.add_argument("--boxes_path", type=str, default=None, - help="Path to the saved box prompts") - return parser From b470257c790ca3547f39b00606202c67a0afefde Mon Sep 17 00:00:00 2001 From: anwai98 Date: Thu, 12 Oct 2023 13:23:06 +0200 Subject: [PATCH 10/21] Add tempfile for model conversion inputs --- micro_sam/modelzoo/bioimageio_export.py | 79 ++++++++++++++----------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/micro_sam/modelzoo/bioimageio_export.py b/micro_sam/modelzoo/bioimageio_export.py index d03a366e4..6d4ae3927 100644 --- a/micro_sam/modelzoo/bioimageio_export.py +++ b/micro_sam/modelzoo/bioimageio_export.py @@ -1,4 +1,5 @@ import os +from tempfile import NamedTemporaryFile as tmp_file import numpy as np from typing import Optional, Union @@ -13,29 +14,30 @@ def _get_model(image, model_type, checkpoint_path): "Returns the model and predictor while initializing with the model checkpoints" - predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True) # type: ignore + predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True, + checkpoint_path=checkpoint_path) # type: ignore image_embeddings = util.precompute_image_embeddings(predictor, image) util.set_precomputed(predictor, image_embeddings) return predictor, sam_model -# TODO use tempfile -def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path): +def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, + tmp_input_path, tmp_boxes_path, tmp_output_path): # For now we just generate a single box prompt here, but we could also generate more input prompts. 
generator = PointAndBoxPromptGenerator(0, 0, 4, False, True) centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels) - masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1]) + masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1]) # type: ignore _, _, box_prompts, _ = generator(masks, [bounding_boxes[1]]) box_prompts = box_prompts.numpy() - save_image_path = "./test-livecell-image.npy" + save_image_path = tmp_input_path.name np.save(save_image_path, image[None, None]) _, sam_model = _get_model(image, model_type, checkpoint_path) predictor = PredictorAdaptor(sam_model=sam_model) - save_box_prompt_path = "./test-box-prompts.npy" + save_box_prompt_path = tmp_boxes_path.name np.save(save_box_prompt_path, box_prompts) input_ = util._to_image(image).transpose(2, 0, 1) @@ -47,7 +49,7 @@ def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path): box_prompts=torch.from_numpy(box_prompts)[None] ) - save_output_path = "./test-livecell-output.npy" + save_output_path = tmp_output_path.name np.save(save_output_path, instances.numpy()) return [save_image_path, save_box_prompt_path], [save_output_path] @@ -70,31 +72,38 @@ def export_bioimageio_model( doc_path: Optional[Union[str, os.PathLike]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, ): - input_paths, result_paths = _create_test_inputs_and_outputs( - image, label_image, model_type, checkpoint_path - ) - checkpoint = util._get_checkpoint(model_type, checkpoint_path=checkpoint_path) - - architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") - - if doc_path is None: - doc_path = "./doc.md" - _get_documentation(doc_path) - - build_model( - weight_uri=checkpoint, - test_inputs=input_paths, # type: ignore - test_outputs=result_paths, # type: ignore - input_axes=["bcyx"], - output_axes=["bcyx"], - name=model_name, - description="Finetuned Segment Anything models for Microscopy", - authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, - {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], - tags=["instance-segmentation", "segment-anything"], - license="CC-BY-4.0", - documentation=doc_path, - cite=[{"text": "Archit, ..., Pape et al. 
Segment Anything for Microscopy",
+                  "doi": "10.1101/2023.08.21.554208"}],
+            output_path=output_path,  # type: ignore
+            architecture=architecture_path
+        )

From de6b2454d9c08657d1920682c442c46c73e1152d Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 12 Oct 2023 13:27:35 +0200
Subject: [PATCH 11/21] Add doc-strings to bioengine export functionality

---
 micro_sam/modelzoo/bioengine_export.py | 69 +++++++++++++++++++++-----
 1 file changed, 57 insertions(+), 12 deletions(-)

diff --git a/micro_sam/modelzoo/bioengine_export.py b/micro_sam/modelzoo/bioengine_export.py
index 63fbd102a..3201920b3 100644
--- a/micro_sam/modelzoo/bioengine_export.py
+++ b/micro_sam/modelzoo/bioengine_export.py
@@ -59,16 +59,27 @@
 }"""


-def to_numpy(tensor):
+def _to_numpy(tensor):
     return tensor.cpu().numpy()


 def export_image_encoder(
     model_type: str,
     output_root: Union[str, os.PathLike],
-    checkpoint_path: Optional[str] = None,
     export_name: Optional[str] = None,
-):
+    checkpoint_path: Optional[str] = None,
+) -> None:
+    """Export the SAM image encoder to torchscript.
+
+    The torchscript image encoder can be used for predicting image embeddings
+    with a backend, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner).
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the SAM model is saved.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+    """
     if export_name is None:
         export_name = model_type
     name = f"sam-{export_name}-encoder"
@@ -91,19 +102,35 @@
 def export_onnx_model(
     model_type,
     output_root,
     opset: int,
+    export_name: Optional[str] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_single_mask: bool = True,
     gelu_approximate: bool = False,
     use_stability_score: bool = False,
     return_extra_metrics: bool = False,
-    export_name: Optional[str] = None,
-):
+) -> None:
+    """Export the SAM prompt encoder and mask decoder to onnx.
+
+    The onnx encoder and decoder can be used for interactive segmentation in the browser.
+ This code is adapted from + https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py + + Args: + model_type: The SAM model type. + output_root: The output root directory where the SAM model is saved. + opset: The ONNX opset version. + export_name: The name of the exported model. + checkpoint_path: Optional checkpoint for loading the SAM model. + return_single_mask: Whether the mask decoder returns a single or multiple masks. + gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend + does not have an efficient GeLU implementation. + use_stability_score: Whether to use the stability score instead of the predicted score. + return_extra_metrics: Whether to return a larger set of metrics. + """ if export_name is None: export_name = model_type name = f"sam-{export_name}-decoder" @@ -168,7 +195,7 @@ def export_onnx_model( ) if onnxruntime_exists: - ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} + ort_inputs = {k: _to_numpy(v) for k, v in dummy_inputs.items()} # set cpu provider default providers = ["CPUExecutionProvider"] ort_session = onnxruntime.InferenceSession(weight_path, providers=providers) @@ -184,22 +211,40 @@ def export_bioengine_model( model_type, output_root, opset: int, + export_name: Optional[str] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_single_mask: bool = True, gelu_approximate: bool = False, use_stability_score: bool = False, return_extra_metrics: bool = False, - export_name: Optional[str] = None, -): - export_image_encoder(model_type, output_root, checkpoint_path, export_name) +) -> None: + """Export the SAM model to a format compatible with the BioEngine. + + [The bioengine](https://github.com/bioimage-io/bioengine-model-runner) enables running the + image encoder on an online backend, so that SAM can be used in an online tool, or to predict + the image embeddings via the online backend rather than on CPU. + + Args: + model_type: The SAM model type. + output_root: The output root directory where the SAM model is saved. + opset: The ONNX opset version. + export_name: The name of the exported model. + checkpoint_path: Optional checkpoint for loading the SAM model. + return_single_mask: Whether the mask decoder returns a single or multiple masks. + gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend + does not have an efficient GeLU implementation. + use_stability_score: Whether to use the stability score instead of the predicted score. + return_extra_metrics: Whether to return a larger set of metrics. 
+    """
+    export_image_encoder(model_type, output_root, export_name, checkpoint_path)
     export_onnx_model(
         model_type=model_type,
         output_root=output_root,
         opset=opset,
+        export_name=export_name,
         checkpoint_path=checkpoint_path,
         return_single_mask=return_single_mask,
         gelu_approximate=gelu_approximate,
         use_stability_score=use_stability_score,
         return_extra_metrics=return_extra_metrics,
-        export_name=export_name
     )

From ebba7191284d06e218c853949b041b9d407ab2ac Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 12 Oct 2023 18:29:13 +0200
Subject: [PATCH 12/21] Update modelzoo export script

---
 micro_sam/modelzoo/bioengine_export.py  | 14 ++--
 micro_sam/modelzoo/bioimageio_export.py | 91 ++++++++++++++++-------
 micro_sam/modelzoo/predictor_adaptor.py | 28 +++++---
 3 files changed, 92 insertions(+), 41 deletions(-)

diff --git a/micro_sam/modelzoo/bioengine_export.py b/micro_sam/modelzoo/bioengine_export.py
index 3201920b3..9559f97d5 100644
--- a/micro_sam/modelzoo/bioengine_export.py
+++ b/micro_sam/modelzoo/bioengine_export.py
@@ -69,16 +69,16 @@ def export_image_encoder(
     export_name: Optional[str] = None,
     checkpoint_path: Optional[str] = None,
 ) -> None:
-    """Export the SAM image encoder to torchscript.
+    """Export SAM image encoder to torchscript.

     The torchscript image encoder can be used for predicting image embeddings
     with a backend, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner).

     Args:
         model_type: The SAM model type.
-        output_root: The output root directory where the SAM model is saved.
+        output_root: The output root directory where the exported model is saved.
         export_name: The name of the exported model.
-        checkpoint_path: Optional checkpoint for loading the SAM model.
+        checkpoint_path: Optional checkpoint for loading the exported model.
     """
@@ -113,7 +113,7 @@ def export_onnx_model(
     use_stability_score: bool = False,
     return_extra_metrics: bool = False,
 ) -> None:
-    """Export the SAM prompt encoder and mask decoder to onnx.
+    """Export SAM prompt encoder and mask decoder to onnx.

     The onnx encoder and decoder can be used for interactive segmentation in the browser.
@@ -121,7 +121,7 @@
     Args:
         model_type: The SAM model type.
-        output_root: The output root directory where the SAM model is saved.
+        output_root: The output root directory where the exported model is saved.
         opset: The ONNX opset version.
         export_name: The name of the exported model.
         checkpoint_path: Optional checkpoint for loading the SAM model.
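For reference, the BioEngine export entry point whose doc-strings are updated above is exercised by examples/modelzoo/export_model_for_bioengine.py from this series. A minimal usage sketch (the output folder name here is an arbitrary choice):

    from micro_sam.modelzoo import export_bioengine_model

    # Writes one model directory per component under ./test-export:
    # sam-vit_b-encoder/1/model.pt (the torchscript image encoder) and
    # sam-vit_b-decoder/1/model.onnx (the onnx prompt encoder and mask decoder),
    # each with a config.pbtxt generated from the templates above.
    export_bioengine_model("vit_b", "test-export", opset=12)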
diff --git a/micro_sam/modelzoo/bioimageio_export.py b/micro_sam/modelzoo/bioimageio_export.py index 6d4ae3927..cb6f3cfcf 100644 --- a/micro_sam/modelzoo/bioimageio_export.py +++ b/micro_sam/modelzoo/bioimageio_export.py @@ -21,8 +21,17 @@ def _get_model(image, model_type, checkpoint_path): return predictor, sam_model -def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, - tmp_input_path, tmp_boxes_path, tmp_output_path): +def _create_test_inputs_and_outputs( + image, + labels, + model_type, + checkpoint_path, + input_path, + box_path, + mask_path, + score_path, + embed_path, +): # For now we just generate a single box prompt here, but we could also generate more input prompts. generator = PointAndBoxPromptGenerator(0, 0, 4, False, True) @@ -31,70 +40,103 @@ def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, _, _, box_prompts, _ = generator(masks, [bounding_boxes[1]]) box_prompts = box_prompts.numpy() - save_image_path = tmp_input_path.name + save_image_path = input_path.name np.save(save_image_path, image[None, None]) _, sam_model = _get_model(image, model_type, checkpoint_path) predictor = PredictorAdaptor(sam_model=sam_model) - save_box_prompt_path = tmp_boxes_path.name + save_box_prompt_path = box_path.name np.save(save_box_prompt_path, box_prompts) input_ = util._to_image(image).transpose(2, 0, 1) - # TODO embeddings are also expected output - instances = predictor( + masks, scores, embeddings = predictor( input_image=torch.from_numpy(input_)[None], image_embeddings=None, box_prompts=torch.from_numpy(box_prompts)[None] ) - save_output_path = tmp_output_path.name - np.save(save_output_path, instances.numpy()) + np.save(mask_path.name, masks.numpy()) + np.save(score_path.name, scores.numpy()) + np.save(embed_path.name, embeddings.numpy()) - return [save_image_path, save_box_prompt_path], [save_output_path] + return [save_image_path, save_box_prompt_path], [mask_path.name, score_path.name, embed_path.name] -def _get_documentation(doc_path): +def _write_documentation(doc_path, doc): with open(doc_path, "w") as f: - f.write("# Segment Anything for Microscopy\n") - f.write("We extend Segment Anything, a vision foundation model for image segmentation ") - f.write("by training specialized models for microscopy data.\n") + if doc is None: + f.write("# Segment Anything for Microscopy\n") + f.write("We extend Segment Anything, a vision foundation model for image segmentation ") + f.write("by training specialized models for microscopy data.\n") + else: + f.write(doc) return doc_path +# TODO enable over-riding the authors and citation and tags from kwargs +# TODO support RGB sample inputs def export_bioimageio_model( image: np.ndarray, label_image: np.ndarray, model_type: str, model_name: str, output_path: Union[str, os.PathLike], - doc_path: Optional[Union[str, os.PathLike]] = None, + doc: Optional[str] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, -): + **kwargs +) -> None: + """Export SAM model to BioImage.IO model format. + + The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and + be used in tools that support the BioImage.IO model format. + + Args: + image: The image for generating test data. + label_image: The segmentation correspoding to `image`. + It is used to derive prompt inputs for the model. + model_type: The type of the SAM model. + model_name: The name of the exported model. + output_path: Where the exported model is saved. + doc: Documentation for the model. 
+ checkpoint_path: Optional checkpoint for loading the SAM model. + kwargs: optional keyword arguments for the 'build_model' function + that converts to the modelzoo format. + """ with ( tmp_file(suffix=".md") as tmp_doc_path, tmp_file(suffix=".npy") as tmp_input_path, tmp_file(suffix=".npy") as tmp_boxes_path, - tmp_file(suffix=".npy") as tmp_output_path + tmp_file(suffix=".npy") as tmp_mask_path, + tmp_file(suffix=".npy") as tmp_score_path, + tmp_file(suffix=".npy") as tmp_embed_path, ): input_paths, result_paths = _create_test_inputs_and_outputs( - image, label_image, model_type, checkpoint_path, tmp_input_path, tmp_boxes_path, tmp_output_path + image, label_image, model_type, checkpoint_path, + input_path=tmp_input_path, + box_path=tmp_boxes_path, + mask_path=tmp_mask_path, + score_path=tmp_score_path, + embed_path=tmp_embed_path, ) checkpoint = util._get_checkpoint(model_type, checkpoint_path=checkpoint_path) architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") - if doc_path is None: - doc_path = tmp_doc_path.name - _get_documentation(doc_path) + doc_path = tmp_doc_path.name + _write_documentation(doc_path, doc) build_model( weight_uri=checkpoint, # type: ignore test_inputs=input_paths, test_outputs=result_paths, - input_axes=["bcyx"], - output_axes=["bcyx"], + input_axes=["bcyx", "bic"], + # FIXME this causes some error in build-model + # input_names=["image", "box-prompts"], + output_axes=["bcyx", "bic", "bcyx"], + # FIXME this causes some error in build-model + # output_names=["masks", "scores", "image_embeddings"], name=model_name, description="Finetuned Segment Anything models for Microscopy", authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, @@ -105,5 +147,8 @@ def export_bioimageio_model( cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], output_path=output_path, # type: ignore - architecture=architecture_path + architecture=architecture_path, + **kwargs, ) + + # TODO actually test the model diff --git a/micro_sam/modelzoo/predictor_adaptor.py b/micro_sam/modelzoo/predictor_adaptor.py index 8fc4a058a..548f3699a 100644 --- a/micro_sam/modelzoo/predictor_adaptor.py +++ b/micro_sam/modelzoo/predictor_adaptor.py @@ -22,28 +22,34 @@ def __call__( - image_embeddings: precomputed image embeddings - box_prompts: box prompts of dimensions C x 4 """ - if self.is_image_set and image_embeddings is None: # we have embeddings set and not passed + # We have image embeddings set and image embeddings were not passed. + if self.is_image_set and image_embeddings is None: pass # do nothing + + # We have image embeddings set and image embeddings were passed. elif self.is_image_set and image_embeddings is not None: - raise NotImplementedError # TODO: replace the image embeedings + self.features = image_embeddings + + # We don't have image embeddings set and image embeddings were passed. 
elif image_embeddings is not None: - pass # TODO set the image embeddings - # self.features = image_embeddings + self.features = image_embeddings + + # We don't have image embeddings set and they were not apassed elif not self.is_image_set: image = self.transform.apply_image_torch(input_image) - self.set_torch_image(image, original_image_size=input_image.numpy().shape[2:]) # compute the image embeddings + self.set_torch_image(image, original_image_size=input_image.numpy().shape[2:]) - boxes = self.transform.apply_boxes_torch(box_prompts, original_size=input_image.numpy().shape[2:]) # type: ignore + boxes = self.transform.apply_boxes_torch(box_prompts, original_size=input_image.numpy().shape[2:]) - instance_segmentation, _, _ = self.predict_torch( + masks, scores, _ = self.predict_torch( point_coords=None, point_labels=None, boxes=boxes, multimask_output=False ) - assert instance_segmentation.shape[2:] == input_image.shape[2:], f"{instance_segmentation.shape[2:]} is not as expected ({input_image.shape[2:]})" + assert masks.shape[2:] == input_image.shape[2:],\ + f"{masks.shape[2:]} is not as expected ({input_image.shape[2:]})" - # TODO get the image embeddings via image_embeddings = self.features - # and return them - return instance_segmentation + image_embeddings = self.features + return masks, scores, image_embeddings From ee831ef6857be6d04653d8ec0cf3b7671d302abd Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 13 Oct 2023 09:57:05 +0200 Subject: [PATCH 13/21] Update url in imjoy test --- examples/modelzoo/imjoy_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/modelzoo/imjoy_test.py b/examples/modelzoo/imjoy_test.py index 77e78b741..2bd500b5a 100644 --- a/examples/modelzoo/imjoy_test.py +++ b/examples/modelzoo/imjoy_test.py @@ -7,6 +7,8 @@ ) # SERVER_URL = 'http://127.0.0.1:9520' # "https://ai.imjoy.io" +# SERVER_URL = "https://hypha.bioimage.io" +# SERVER_URL = "https://ai.imjoy.io" SERVER_URL = "https://hypha.bioimage.io" From 390ce234262e0c4d6517751bac6d267a90b42f05 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 14 Mar 2024 20:30:30 +0100 Subject: [PATCH 14/21] Update to bioimageio.spec v0.5 WIP --- .../modelzoo/export_model_for_bioimageio.py | 4 +- micro_sam/modelzoo/bioimageio_export.py | 257 ++++++++++++++---- micro_sam/modelzoo/predictor_adaptor.py | 99 +++++-- 3 files changed, 271 insertions(+), 89 deletions(-) diff --git a/examples/modelzoo/export_model_for_bioimageio.py b/examples/modelzoo/export_model_for_bioimageio.py index cd4214533..346a84e5b 100644 --- a/examples/modelzoo/export_model_for_bioimageio.py +++ b/examples/modelzoo/export_model_for_bioimageio.py @@ -3,11 +3,11 @@ def export_model_with_synthetic_data(): - image, labels = synthetic_data(shape=(1024, 1024)) + image, labels = synthetic_data(shape=(1024, 1022)) export_bioimageio_model( image, labels, - model_type="vit_b", model_name="sam-test-vit-b", + model_type="vit_b", name="sam-test-vit-b", output_path="./test_export.zip", ) diff --git a/micro_sam/modelzoo/bioimageio_export.py b/micro_sam/modelzoo/bioimageio_export.py index cb6f3cfcf..cd6eab90d 100644 --- a/micro_sam/modelzoo/bioimageio_export.py +++ b/micro_sam/modelzoo/bioimageio_export.py @@ -1,24 +1,30 @@ import os +from pathlib import Path from tempfile import NamedTemporaryFile as tmp_file -import numpy as np from typing import Optional, Union +import bioimageio.spec.model.v0_5 as spec +import numpy as np import torch -from bioimageio.core.build_spec import build_model +from bioimageio.spec import 
save_bioimageio_package + from .. import util from ..prompt_generators import PointAndBoxPromptGenerator from .predictor_adaptor import PredictorAdaptor - -def _get_model(image, model_type, checkpoint_path): - "Returns the model and predictor while initializing with the model checkpoints" - predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True, - checkpoint_path=checkpoint_path) # type: ignore - image_embeddings = util.precompute_image_embeddings(predictor, image) - util.set_precomputed(predictor, image_embeddings) - return predictor, sam_model +# TODO extend the defaults +DEFAULTS = { + "authors": [ + spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"), + spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"), + ], + "description": "Finetuned Segment Anything Model for Microscopy", + "cite": [ + spec.CiteEntry(text="Archit et al. Segment Anything for Microscopy", doi=spec.Doi("10.1101/2023.08.21.554208")), + ] +} def _create_test_inputs_and_outputs( @@ -36,34 +42,45 @@ def _create_test_inputs_and_outputs( # For now we just generate a single box prompt here, but we could also generate more input prompts. generator = PointAndBoxPromptGenerator(0, 0, 4, False, True) centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels) - masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1]) # type: ignore - _, _, box_prompts, _ = generator(masks, [bounding_boxes[1]]) - box_prompts = box_prompts.numpy() - - save_image_path = input_path.name - np.save(save_image_path, image[None, None]) + masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2]) # type: ignore + _, _, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]]) + box_prompts = box_prompts.numpy()[None] - _, sam_model = _get_model(image, model_type, checkpoint_path) - predictor = PredictorAdaptor(sam_model=sam_model) + predictor = PredictorAdaptor(model_type=model_type) + predictor.load_state_dict(torch.load(checkpoint_path)) save_box_prompt_path = box_path.name np.save(save_box_prompt_path, box_prompts) - input_ = util._to_image(image).transpose(2, 0, 1) + input_ = util._to_image(image).transpose(2, 0, 1)[None] + save_image_path = input_path.name + np.save(save_image_path, input_) masks, scores, embeddings = predictor( - input_image=torch.from_numpy(input_)[None], - image_embeddings=None, - box_prompts=torch.from_numpy(box_prompts)[None] + image=torch.from_numpy(input_), + embeddings=None, + box_prompts=torch.from_numpy(box_prompts) ) np.save(mask_path.name, masks.numpy()) np.save(score_path.name, scores.numpy()) np.save(embed_path.name, embeddings.numpy()) - return [save_image_path, save_box_prompt_path], [mask_path.name, score_path.name, embed_path.name] + # TODO autogenerate the cover and return it too. 
+ + inputs = { + "image": save_image_path, + "box_prompts": save_box_prompt_path, + } + outputs = { + "mask": mask_path.name, + "score": score_path.name, + "embeddings": embed_path.name + } + return inputs, outputs +# TODO url with documentation for the modelzoo interface, and just add it to defaults def _write_documentation(doc_path, doc): with open(doc_path, "w") as f: if doc is None: @@ -75,15 +92,19 @@ def _write_documentation(doc_path, doc): return doc_path -# TODO enable over-riding the authors and citation and tags from kwargs -# TODO support RGB sample inputs +def _get_checkpoint(model_type, checkpoint_path): + if checkpoint_path is None: + model_registry = util.models() + checkpoint_path = model_registry.fetch(model_type) + return checkpoint_path + + def export_bioimageio_model( image: np.ndarray, label_image: np.ndarray, model_type: str, - model_name: str, + name: str, output_path: Union[str, os.PathLike], - doc: Optional[str] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, **kwargs ) -> None: @@ -97,12 +118,9 @@ def export_bioimageio_model( label_image: The segmentation correspoding to `image`. It is used to derive prompt inputs for the model. model_type: The type of the SAM model. - model_name: The name of the exported model. + name: The name of the exported model. output_path: Where the exported model is saved. - doc: Documentation for the model. checkpoint_path: Optional checkpoint for loading the SAM model. - kwargs: optional keyword arguments for the 'build_model' function - that converts to the modelzoo format. """ with ( tmp_file(suffix=".md") as tmp_doc_path, @@ -112,6 +130,7 @@ def export_bioimageio_model( tmp_file(suffix=".npy") as tmp_score_path, tmp_file(suffix=".npy") as tmp_embed_path, ): + checkpoint_path = _get_checkpoint(model_type, checkpoint_path=checkpoint_path) input_paths, result_paths = _create_test_inputs_and_outputs( image, label_image, model_type, checkpoint_path, input_path=tmp_input_path, @@ -120,35 +139,159 @@ def export_bioimageio_model( score_path=tmp_score_path, embed_path=tmp_embed_path, ) - checkpoint = util._get_checkpoint(model_type, checkpoint_path=checkpoint_path) + input_descriptions = [ + # First input: the image data. + spec.InputTensorDescr( + id=spec.TensorId("image"), + axes=[ + spec.BatchAxis(), + # NOTE: to support 1 and 3 channels we can add another preprocessing. + # Best solution: Have a pre-processing for this! 
(1C -> RGB) + spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]), + spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE), + spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE), + ], + test_tensor=spec.FileDescr(source=input_paths["image"]), + data=spec.IntervalOrRatioDataDescr(type="uint8") + ), + + # Second input: the box prompts (optional) + spec.InputTensorDescr( + id=spec.TensorId("box_prompts"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + # TODO double check the axis names + spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]), + ], + test_tensor=spec.FileDescr(source=input_paths["box_prompts"]), + data=spec.IntervalOrRatioDataDescr(type="int64") + ), + + # TODO + # Third input: the point prompts (optional) + # TODO + # Fourth input: the mask prompts (optional) + + # Fifth input: the image embeddings (optional) + spec.InputTensorDescr( + id=spec.TensorId("embeddings"), + optional=True, + axes=[ + spec.BatchAxis(), + # NOTE: we currently have to specify all the channel names + # (It would be nice to also support size) + spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]), + spec.SpaceInputAxis(id=spec.AxisId("y"), size=64), + spec.SpaceInputAxis(id=spec.AxisId("x"), size=64), + ], + test_tensor=spec.FileDescr(source=result_paths["embeddings"]), + data=spec.IntervalOrRatioDataDescr(type="float32") + ), + + ] + + output_descriptions = [ + # First output: The mask predictions. + spec.OutputTensorDescr( + id=spec.TensorId("masks"), + axes=[ + spec.BatchAxis(), + spec.IndexAxis( + id=spec.AxisId("object"), + size=spec.SizeReference( + tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object") + ) + ), + # NOTE: this could be a 3 once we use multi-masking + spec.ChannelAxis(channel_names=[spec.Identifier("mask")]), + spec.SpaceOutputAxis( + id=spec.AxisId("y"), + size=spec.SizeReference( + tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"), + ) + ), + spec.SpaceOutputAxis( + id=spec.AxisId("x"), + size=spec.SizeReference( + tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"), + ) + ) + ], + data=spec.IntervalOrRatioDataDescr(type="uint8"), + test_tensor=spec.FileDescr(source=result_paths["mask"]) + ), + + # The score predictions + spec.OutputTensorDescr( + id=spec.TensorId("scores"), + axes=[ + spec.BatchAxis(), + spec.IndexAxis( + id=spec.AxisId("object"), + size=spec.SizeReference( + tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object") + ) + ), + # NOTE: this could be a 3 once we use multi-masking + spec.ChannelAxis(channel_names=[spec.Identifier("mask")]), + ], + data=spec.IntervalOrRatioDataDescr(type="float32"), + test_tensor=spec.FileDescr(source=result_paths["score"]) + ), + + # The image embeddings + spec.OutputTensorDescr( + id=spec.TensorId("embeddings"), + axes=[ + spec.BatchAxis(), + spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]), + spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64), + spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64), + ], + data=spec.IntervalOrRatioDataDescr(type="float32"), + test_tensor=spec.FileDescr(source=result_paths["embeddings"]) + ) + ] + + # TODO sha256 architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") + architecture = spec.ArchitectureFromFileDescr( + source=Path(architecture_path), + callable="PredictorAdaptor", + kwargs={"model_type": model_type} + ) + + 
weight_descriptions = spec.WeightsDescr( + pytorch_state_dict=spec.PytorchStateDictWeightsDescr( + source=Path(checkpoint_path), + architecture=architecture, + pytorch_version=spec.Version(torch.__version__), + ) + ) doc_path = tmp_doc_path.name - _write_documentation(doc_path, doc) - - build_model( - weight_uri=checkpoint, # type: ignore - test_inputs=input_paths, - test_outputs=result_paths, - input_axes=["bcyx", "bic"], - # FIXME this causes some error in build-model - # input_names=["image", "box-prompts"], - output_axes=["bcyx", "bic", "bcyx"], - # FIXME this causes some error in build-model - # output_names=["masks", "scores", "image_embeddings"], - name=model_name, - description="Finetuned Segment Anything models for Microscopy", - authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, - {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], - tags=["instance-segmentation", "segment-anything"], - license="CC-BY-4.0", - documentation=doc_path, # type: ignore - cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", - "doi": "10.1101/2023.08.21.554208"}], - output_path=output_path, # type: ignore - architecture=architecture_path, - **kwargs, + _write_documentation(doc_path, kwargs.get("documentation", None)) + + # TODO tags, dependencies, other stuff ... + model_description = spec.ModelDescr( + name=name, + description=kwargs.get("description", DEFAULTS["description"]), + authors=kwargs.get("authors", DEFAULTS["authors"]), + cite=kwargs.get("cite", DEFAULTS["cite"]), + license=spec.LicenseId("MIT"), + documentation=Path(doc_path), + git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"), + inputs=input_descriptions, + outputs=output_descriptions, + weights=weight_descriptions, ) - # TODO actually test the model + # TODO test the model. + + save_bioimageio_package(model_description, output_path=output_path) diff --git a/micro_sam/modelzoo/predictor_adaptor.py b/micro_sam/modelzoo/predictor_adaptor.py index 548f3699a..13ec32835 100644 --- a/micro_sam/modelzoo/predictor_adaptor.py +++ b/micro_sam/modelzoo/predictor_adaptor.py @@ -1,55 +1,94 @@ -from typing import Optional +import warnings +from typing import Optional, Tuple import torch +from torch import nn from segment_anything.predictor import SamPredictor +try: + # Avoid import warnings from mobile_sam + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from mobile_sam import sam_model_registry +except ImportError: + from segment_anything import sam_model_registry -class PredictorAdaptor(SamPredictor): - """Wrapper around the SamPredictor to be used by BioImage.IO model format. + +# TODO we need to accept and return an additional tensor for the image sizes to support embeddings +class PredictorAdaptor(nn.Module): + """Wrapper around the SamPredictor. This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts. + + Args: + model_type: The type of the model for the image encoder. + Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. + For 'vit_t' support the 'mobile_sam' package has to be installed. 
""" - def __call__( - self, - input_image: torch.Tensor, - image_embeddings: Optional[torch.Tensor] = None, - box_prompts: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Expected inputs: - - input_image: torch inputs of dimensions B x C x H x W - - image_embeddings: precomputed image embeddings - - box_prompts: box prompts of dimensions C x 4 + def __init__(self, model_type: str) -> None: + super().__init__() + sam_model = sam_model_registry[model_type]() + self.sam = SamPredictor(sam_model) + + def load_state_dict(self, state): + self.sam.model.load_state_dict(state) + + @torch.no_grad() + def forward( + self, + image: torch.Tensor, + box_prompts: Optional[torch.Tensor] = None, + # TODO add point and mask prompts + embeddings: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + + Args: + image: torch inputs of dimensions B x C x H x W + box_prompts: box prompts of dimensions B x OBJECTS x 4 + embeddings: precomputed image embeddings B x 256 x 64 x 64 + + Returns: """ + batch_size = image.shape[0] + if batch_size != 1: + raise ValueError + # We have image embeddings set and image embeddings were not passed. - if self.is_image_set and image_embeddings is None: + if self.sam.is_image_set and embeddings is None: pass # do nothing - # We have image embeddings set and image embeddings were passed. - elif self.is_image_set and image_embeddings is not None: - self.features = image_embeddings - - # We don't have image embeddings set and image embeddings were passed. - elif image_embeddings is not None: - self.features = image_embeddings + # The embeddings are passed, so we set them. + elif embeddings is not None: + self.sam.features = embeddings + self.sam.orig_h, self.sam.orig_w = image.shape[2:] + self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:] + self.sam.is_image_set = True # We don't have image embeddings set and they were not apassed - elif not self.is_image_set: - image = self.transform.apply_image_torch(input_image) - self.set_torch_image(image, original_image_size=input_image.numpy().shape[2:]) + elif not self.sam.is_image_set: + image = self.sam.transform.apply_image_torch(image) + self.sam.set_torch_image(image, original_image_size=image.numpy().shape[2:]) - boxes = self.transform.apply_boxes_torch(box_prompts, original_size=input_image.numpy().shape[2:]) + boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=image.numpy().shape[2:]) - masks, scores, _ = self.predict_torch( + masks, scores, _ = self.sam.predict_torch( point_coords=None, point_labels=None, boxes=boxes, multimask_output=False ) - assert masks.shape[2:] == input_image.shape[2:],\ - f"{masks.shape[2:]} is not as expected ({input_image.shape[2:]})" + assert masks.shape[2:] == image.shape[2:], \ + f"{masks.shape[2:]} is not as expected ({image.shape[2:]})" + + # Ensure batch axis. 
+ if masks.ndim == 4: + masks = masks[None] + assert scores.ndim == 2 + scores = scores[None] - image_embeddings = self.features - return masks, scores, image_embeddings + embeddings = self.sam.get_image_embedding() + return masks.to(dtype=torch.uint8), scores, embeddings From 6cccc0608e0c7d605a7cb17c98852c7c6a343730 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 15 Mar 2024 17:37:26 +0100 Subject: [PATCH 15/21] Update example script --- examples/modelzoo/export_model_for_bioimageio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modelzoo/export_model_for_bioimageio.py b/examples/modelzoo/export_model_for_bioimageio.py index 346a84e5b..e42cde478 100644 --- a/examples/modelzoo/export_model_for_bioimageio.py +++ b/examples/modelzoo/export_model_for_bioimageio.py @@ -7,7 +7,7 @@ def export_model_with_synthetic_data(): export_bioimageio_model( image, labels, - model_type="vit_b", name="sam-test-vit-b", + model_type="vit_t", name="sam-test-vit-t", output_path="./test_export.zip", ) From b9df849d15016aac32b05ee3a0413e8f4fe71927 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 18 Mar 2024 19:53:33 +0100 Subject: [PATCH 16/21] Update bioimageio export --- .../export_model_for_bioengine.py | 0 .../export_model_for_bioimageio.py | 4 +- .../{modelzoo => bioimageio}/imjoy_test.py | 0 micro_sam/bioimageio/__init__.py | 1 + .../bioengine_export.py | 0 .../model_export.py} | 134 ++++++++++++------ .../predictor_adaptor.py | 0 micro_sam/evaluation/model_comparison.py | 2 +- micro_sam/modelzoo/__init__.py | 2 - test/test_bioimageio/test_model_export.py | 37 +++++ 10 files changed, 134 insertions(+), 46 deletions(-) rename examples/{modelzoo => bioimageio}/export_model_for_bioengine.py (100%) rename examples/{modelzoo => bioimageio}/export_model_for_bioimageio.py (81%) rename examples/{modelzoo => bioimageio}/imjoy_test.py (100%) create mode 100644 micro_sam/bioimageio/__init__.py rename micro_sam/{modelzoo => bioimageio}/bioengine_export.py (100%) rename micro_sam/{modelzoo/bioimageio_export.py => bioimageio/model_export.py} (76%) rename micro_sam/{modelzoo => bioimageio}/predictor_adaptor.py (100%) delete mode 100644 micro_sam/modelzoo/__init__.py create mode 100644 test/test_bioimageio/test_model_export.py diff --git a/examples/modelzoo/export_model_for_bioengine.py b/examples/bioimageio/export_model_for_bioengine.py similarity index 100% rename from examples/modelzoo/export_model_for_bioengine.py rename to examples/bioimageio/export_model_for_bioengine.py diff --git a/examples/modelzoo/export_model_for_bioimageio.py b/examples/bioimageio/export_model_for_bioimageio.py similarity index 81% rename from examples/modelzoo/export_model_for_bioimageio.py rename to examples/bioimageio/export_model_for_bioimageio.py index e42cde478..5769ef7c2 100644 --- a/examples/modelzoo/export_model_for_bioimageio.py +++ b/examples/bioimageio/export_model_for_bioimageio.py @@ -1,11 +1,11 @@ -from micro_sam.modelzoo import export_bioimageio_model +from micro_sam.bioimageio import export_sam_model from micro_sam.sample_data import synthetic_data def export_model_with_synthetic_data(): image, labels = synthetic_data(shape=(1024, 1022)) - export_bioimageio_model( + export_sam_model( image, labels, model_type="vit_t", name="sam-test-vit-t", output_path="./test_export.zip", diff --git a/examples/modelzoo/imjoy_test.py b/examples/bioimageio/imjoy_test.py similarity index 100% rename from examples/modelzoo/imjoy_test.py rename to examples/bioimageio/imjoy_test.py diff --git 
a/micro_sam/bioimageio/__init__.py b/micro_sam/bioimageio/__init__.py new file mode 100644 index 000000000..f3534f83c --- /dev/null +++ b/micro_sam/bioimageio/__init__.py @@ -0,0 +1 @@ +from .model_export import export_sam_model diff --git a/micro_sam/modelzoo/bioengine_export.py b/micro_sam/bioimageio/bioengine_export.py similarity index 100% rename from micro_sam/modelzoo/bioengine_export.py rename to micro_sam/bioimageio/bioengine_export.py diff --git a/micro_sam/modelzoo/bioimageio_export.py b/micro_sam/bioimageio/model_export.py similarity index 76% rename from micro_sam/modelzoo/bioimageio_export.py rename to micro_sam/bioimageio/model_export.py index cd6eab90d..2aecfea55 100644 --- a/micro_sam/modelzoo/bioimageio_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -1,9 +1,11 @@ import os +import tempfile + from pathlib import Path -from tempfile import NamedTemporaryFile as tmp_file from typing import Optional, Union import bioimageio.spec.model.v0_5 as spec +import matplotlib.pyplot as plt import numpy as np import torch @@ -12,9 +14,9 @@ from .. import util from ..prompt_generators import PointAndBoxPromptGenerator +from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box from .predictor_adaptor import PredictorAdaptor -# TODO extend the defaults DEFAULTS = { "authors": [ spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"), @@ -23,7 +25,8 @@ "description": "Finetuned Segment Anything Model for Microscopy", "cite": [ spec.CiteEntry(text="Archit et al. Segment Anything for Microscopy", doi=spec.Doi("10.1101/2023.08.21.554208")), - ] + ], + "tags": ["segment-anything", "instance-segmentation"] } @@ -32,11 +35,7 @@ def _create_test_inputs_and_outputs( labels, model_type, checkpoint_path, - input_path, - box_path, - mask_path, - score_path, - embed_path, + tmp_dir, ): # For now we just generate a single box prompt here, but we could also generate more input prompts. @@ -49,11 +48,11 @@ def _create_test_inputs_and_outputs( predictor = PredictorAdaptor(model_type=model_type) predictor.load_state_dict(torch.load(checkpoint_path)) - save_box_prompt_path = box_path.name + save_box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy") np.save(save_box_prompt_path, box_prompts) input_ = util._to_image(image).transpose(2, 0, 1)[None] - save_image_path = input_path.name + save_image_path = os.path.join(tmp_dir, "input.npy") np.save(save_image_path, input_) masks, scores, embeddings = predictor( @@ -62,9 +61,12 @@ def _create_test_inputs_and_outputs( box_prompts=torch.from_numpy(box_prompts) ) - np.save(mask_path.name, masks.numpy()) - np.save(score_path.name, scores.numpy()) - np.save(embed_path.name, embeddings.numpy()) + mask_path = os.path.join(tmp_dir, "mask.npy") + score_path = os.path.join(tmp_dir, "scores.npy") + embed_path = os.path.join(tmp_dir, "embeddings.npy") + np.save(mask_path, masks.numpy()) + np.save(score_path, scores.numpy()) + np.save(embed_path, embeddings.numpy()) # TODO autogenerate the cover and return it too. 
@@ -73,9 +75,9 @@ def _create_test_inputs_and_outputs( "box_prompts": save_box_prompt_path, } outputs = { - "mask": mask_path.name, - "score": score_path.name, - "embeddings": embed_path.name + "mask": mask_path, + "score": score_path, + "embeddings": embed_path } return inputs, outputs @@ -99,7 +101,56 @@ def _get_checkpoint(model_type, checkpoint_path): return checkpoint_path -def export_bioimageio_model( +def _write_dependencies(dependency_file, require_mobile_sam): + content = """name: sam +channels: + - pytorch + - conda-forge +dependencies: + - segment-anything""" + if require_mobile_sam: + content += """ + - pip: + - git+https://github.com/ChaoningZhang/MobileSAM.git""" + with open(dependency_file, "w") as f: + f.write(content) + + +def _generate_covers(input_paths, result_paths, tmp_dir): + image = np.load(input_paths["image"]).squeeze() + prompts = np.load(input_paths["box_prompts"]) + mask = np.load(result_paths["mask"]) + + # create the image overlay + if image.ndim == 2: + overlay = np.stack([image, image, image]).transpose((1, 2, 0)) + elif image.shape[0] == 3: + overlay = image.transpose((1, 2, 0)) + else: + overlay = image + overlay = _enhance_image(overlay.astype("float32")) + + # overlay the mask as outline + overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2) + + # overlay the bounding box prompt + prompt = prompts[0, 0][[1, 0, 3, 2]] + prompt = np.array([prompt[:2], prompt[2:]]) + overlay = _overlay_box(overlay, prompt, outline_dilation=4) + + # write the cover image + fig, ax = plt.subplots(1) + ax.axis("off") + ax.imshow(overlay.astype("uint8")) + cover_path = os.path.join(tmp_dir, "cover.jpeg") + plt.savefig(cover_path, bbox_inches="tight") + plt.close() + + covers = [cover_path] + return covers + + +def export_sam_model( image: np.ndarray, label_image: np.ndarray, model_type: str, @@ -122,22 +173,10 @@ def export_bioimageio_model( output_path: Where the exported model is saved. checkpoint_path: Optional checkpoint for loading the SAM model. """ - with ( - tmp_file(suffix=".md") as tmp_doc_path, - tmp_file(suffix=".npy") as tmp_input_path, - tmp_file(suffix=".npy") as tmp_boxes_path, - tmp_file(suffix=".npy") as tmp_mask_path, - tmp_file(suffix=".npy") as tmp_score_path, - tmp_file(suffix=".npy") as tmp_embed_path, - ): + with tempfile.TemporaryDirectory() as tmp_dir: checkpoint_path = _get_checkpoint(model_type, checkpoint_path=checkpoint_path) input_paths, result_paths = _create_test_inputs_and_outputs( - image, label_image, model_type, checkpoint_path, - input_path=tmp_input_path, - box_path=tmp_boxes_path, - mask_path=tmp_mask_path, - score_path=tmp_score_path, - embed_path=tmp_embed_path, + image, label_image, model_type, checkpoint_path, tmp_dir, ) input_descriptions = [ # First input: the image data. 
@@ -161,7 +200,7 @@ def export_bioimageio_model( optional=True, axes=[ spec.BatchAxis(), - spec.IndexAxis( + spec.IndexInputAxis( id=spec.AxisId("object"), size=spec.ARBITRARY_SIZE ), @@ -202,7 +241,7 @@ def export_bioimageio_model( id=spec.TensorId("masks"), axes=[ spec.BatchAxis(), - spec.IndexAxis( + spec.IndexOutputAxis( id=spec.AxisId("object"), size=spec.SizeReference( tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object") @@ -232,7 +271,7 @@ def export_bioimageio_model( id=spec.TensorId("scores"), axes=[ spec.BatchAxis(), - spec.IndexAxis( + spec.IndexOutputAxis( id=spec.AxisId("object"), size=spec.SizeReference( tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object") @@ -259,7 +298,6 @@ def export_bioimageio_model( ) ] - # TODO sha256 architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") architecture = spec.ArchitectureFromFileDescr( source=Path(architecture_path), @@ -267,29 +305,43 @@ def export_bioimageio_model( kwargs={"model_type": model_type} ) + dependency_file = os.path.join(tmp_dir, "environment.yaml") + _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t")) + # print(dependency_file) + # breakpoint() + weight_descriptions = spec.WeightsDescr( pytorch_state_dict=spec.PytorchStateDictWeightsDescr( source=Path(checkpoint_path), architecture=architecture, pytorch_version=spec.Version(torch.__version__), + # FIXME: this leads to a validation error! + # dependencies=dependency_file, ) ) - doc_path = tmp_doc_path.name + doc_path = os.path.join(tmp_dir, "documentation.md") _write_documentation(doc_path, kwargs.get("documentation", None)) - # TODO tags, dependencies, other stuff ... + covers = _generate_covers(input_paths, result_paths, tmp_dir) + model_description = spec.ModelDescr( name=name, + inputs=input_descriptions, + outputs=output_descriptions, + weights=weight_descriptions, description=kwargs.get("description", DEFAULTS["description"]), authors=kwargs.get("authors", DEFAULTS["authors"]), cite=kwargs.get("cite", DEFAULTS["cite"]), - license=spec.LicenseId("MIT"), + license=spec.LicenseId("CC-BY-4.0"), documentation=Path(doc_path), git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"), - inputs=input_descriptions, - outputs=output_descriptions, - weights=weight_descriptions, + tags=kwargs.get("tags", DEFAULTS["tags"]), + covers=covers, + # TODO attach the decoder weights if given + # attachments= + # TODO write the config + # config= ) # TODO test the model. 
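
Note: with the rename in this patch the public entry point is micro_sam.bioimageio.export_sam_model. A minimal end-to-end sketch, mirroring the updated example script (note the non-square shape, which would expose transposed spatial axes):

from micro_sam.bioimageio import export_sam_model
from micro_sam.sample_data import synthetic_data

# Create a synthetic image / label pair and export a test model from it.
image, labels = synthetic_data(shape=(1024, 1022))
export_sam_model(
    image, labels,
    model_type="vit_t", name="sam-test-vit-t",
    output_path="./test_export.zip",
)
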
diff --git a/micro_sam/modelzoo/predictor_adaptor.py b/micro_sam/bioimageio/predictor_adaptor.py similarity index 100% rename from micro_sam/modelzoo/predictor_adaptor.py rename to micro_sam/bioimageio/predictor_adaptor.py diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 6df9e8f40..478ab4584 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -126,7 +126,7 @@ def generate_data_for_model_comparison( # -# Visual evaluation accroding to metrics +# Visual evaluation according to metrics # diff --git a/micro_sam/modelzoo/__init__.py b/micro_sam/modelzoo/__init__.py deleted file mode 100644 index a66dd2da9..000000000 --- a/micro_sam/modelzoo/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bioimageio_export import export_bioimageio_model -from .bioengine_export import export_bioengine_model diff --git a/test/test_bioimageio/test_model_export.py b/test/test_bioimageio/test_model_export.py new file mode 100644 index 000000000..4fbd882c1 --- /dev/null +++ b/test/test_bioimageio/test_model_export.py @@ -0,0 +1,37 @@ +import os +import unittest + +from shutil import rmtree + +import micro_sam.util as util +from micro_sam.sample_data import synthetic_data + + +class TestModelExport(unittest.TestCase): + tmp_folder = "tmp" + model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b" + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) + + def tearDown(self): + rmtree(self.tmp_folder) + + def test_model_export(self): + from micro_sam.bioimageio import export_sam_model + image, labels = synthetic_data(shape=(1024, 1022)) + + export_path = os.path.join(self.tmp_folder, "test_export.zip") + export_sam_model( + image, labels, + model_type=self.model_type, name="test-export", + output_path=export_path, + ) + + self.assertTrue(os.path.exists(export_path)) + + # TODO more tests: run prediction with models for different prompt settings + + +if __name__ == "__main__": + unittest.main() From d911392294f1658535146d8dea99ec5a7b9b7475 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 19 Mar 2024 15:47:09 +0100 Subject: [PATCH 17/21] Minor fixes --- micro_sam/bioimageio/model_export.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 2aecfea55..81145dda6 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -241,11 +241,9 @@ def export_sam_model( id=spec.TensorId("masks"), axes=[ spec.BatchAxis(), + # NOTE: we use the data dependent size here to avoid dependency on optional inputs spec.IndexOutputAxis( - id=spec.AxisId("object"), - size=spec.SizeReference( - tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object") - ) + id=spec.AxisId("object"), size=spec.DataDependentSize(), ), # NOTE: this could be a 3 once we use multi-masking spec.ChannelAxis(channel_names=[spec.Identifier("mask")]), @@ -271,11 +269,9 @@ def export_sam_model( id=spec.TensorId("scores"), axes=[ spec.BatchAxis(), + # NOTE: we use the data dependent size here to avoid dependency on optional inputs spec.IndexOutputAxis( - id=spec.AxisId("object"), - size=spec.SizeReference( - tensor_id=spec.TensorId("box_prompts"), axis_id=spec.AxisId("object") - ) + id=spec.AxisId("object"), size=spec.DataDependentSize(), ), # NOTE: this could be a 3 once we use multi-masking spec.ChannelAxis(channel_names=[spec.Identifier("mask")]), @@ -307,16 +303,13 @@ def 
export_sam_model( dependency_file = os.path.join(tmp_dir, "environment.yaml") _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t")) - # print(dependency_file) - # breakpoint() weight_descriptions = spec.WeightsDescr( pytorch_state_dict=spec.PytorchStateDictWeightsDescr( source=Path(checkpoint_path), architecture=architecture, pytorch_version=spec.Version(torch.__version__), - # FIXME: this leads to a validation error! - # dependencies=dependency_file, + dependencies=spec.EnvironmentFileDescr(source=dependency_file), ) ) @@ -339,11 +332,15 @@ def export_sam_model( tags=kwargs.get("tags", DEFAULTS["tags"]), covers=covers, # TODO attach the decoder weights if given - # attachments= + # Can be list of files??? + # attachments=[spec.FileDescr(source=file_path) for file_path in attachment_files] # TODO write the config + # dict with yaml values, key must be a str + # micro_sam: ... # config= ) # TODO test the model. + # Should work, but not tested with optional. save_bioimageio_package(model_description, output_path=output_path) From ec5603566b78a914df61968b39ff7f2ca6045e64 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 19 Mar 2024 21:28:38 +0100 Subject: [PATCH 18/21] Work on export --- micro_sam/bioimageio/model_export.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 81145dda6..12c400858 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -4,10 +4,13 @@ from pathlib import Path from typing import Optional, Union +# FIXME import fails +import bioimageio.core import bioimageio.spec.model.v0_5 as spec import matplotlib.pyplot as plt import numpy as np import torch +import xarray from bioimageio.spec import save_bioimageio_package @@ -150,6 +153,27 @@ def _generate_covers(input_paths, result_paths, tmp_dir): return covers +def _check_model(model_description, input_paths, result_paths): + model = bioimageio.core.load_resource_description(model_description) + + # Load inputs and outputs. + image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx")) + embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx")) + box_prompts = np.load(input_paths["box_prompts"], dims=tuple("bic")) + mask = np.load(result_paths["mask"]) + + breakpoint() + + # Check with box prompt. + with bioimageio.core.create_prediction_pipeline(model) as pp: + prediction = pp.forward( + image=image, + embeddings=embeddings, + box_prompts=box_prompts, + ) + breakpoint() + + def export_sam_model( image: np.ndarray, label_image: np.ndarray, @@ -340,7 +364,6 @@ def export_sam_model( # config= ) - # TODO test the model. - # Should work, but not tested with optional. 
+ _check_model(model_description, input_paths, result_paths) save_bioimageio_package(model_description, output_path=output_path) From 2d9c88a99c0ff1c47dfa3191c2171945d368870e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 20 Mar 2024 21:59:06 +0100 Subject: [PATCH 19/21] More modelzoo updtes --- micro_sam/bioimageio/model_export.py | 102 ++++++++++++++++------ micro_sam/bioimageio/predictor_adaptor.py | 43 +++++++-- 2 files changed, 111 insertions(+), 34 deletions(-) diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 12c400858..7ba1df866 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Optional, Union -# FIXME import fails import bioimageio.core import bioimageio.spec.model.v0_5 as spec import matplotlib.pyplot as plt @@ -42,28 +41,43 @@ def _create_test_inputs_and_outputs( ): # For now we just generate a single box prompt here, but we could also generate more input prompts. - generator = PointAndBoxPromptGenerator(0, 0, 4, False, True) + generator = PointAndBoxPromptGenerator( + n_positive_points=1, + n_negative_points=2, + dilation_strength=2, + get_point_prompts=True, + get_box_prompts=True, + ) centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels) masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2]) # type: ignore - _, _, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]]) + point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]]) + box_prompts = box_prompts.numpy()[None] + point_prompts = point_prompts.numpy()[None] + point_labels = point_labels.numpy()[None] predictor = PredictorAdaptor(model_type=model_type) predictor.load_state_dict(torch.load(checkpoint_path)) - save_box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy") - np.save(save_box_prompt_path, box_prompts) - input_ = util._to_image(image).transpose(2, 0, 1)[None] - save_image_path = os.path.join(tmp_dir, "input.npy") - np.save(save_image_path, input_) + image_path = os.path.join(tmp_dir, "input.npy") + np.save(image_path, input_) masks, scores, embeddings = predictor( image=torch.from_numpy(input_), embeddings=None, - box_prompts=torch.from_numpy(box_prompts) + box_prompts=torch.from_numpy(box_prompts), + point_prompts=torch.from_numpy(point_prompts), + point_labels=torch.from_numpy(point_labels), ) + box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy") + point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy") + point_label_path = os.path.join(tmp_dir, "point_labels.npy") + np.save(box_prompt_path, box_prompts) + np.save(point_prompt_path, point_prompts) + np.save(point_label_path, point_labels) + mask_path = os.path.join(tmp_dir, "mask.npy") score_path = os.path.join(tmp_dir, "scores.npy") embed_path = os.path.join(tmp_dir, "embeddings.npy") @@ -71,11 +85,11 @@ def _create_test_inputs_and_outputs( np.save(score_path, scores.numpy()) np.save(embed_path, embeddings.numpy()) - # TODO autogenerate the cover and return it too. 
- inputs = { - "image": save_image_path, - "box_prompts": save_box_prompt_path, + "image": image_path, + "box_prompts": box_prompt_path, + "point_prompts": point_prompt_path, + "point_labels": point_label_path, } outputs = { "mask": mask_path, @@ -154,24 +168,26 @@ def _generate_covers(input_paths, result_paths, tmp_dir): def _check_model(model_description, input_paths, result_paths): - model = bioimageio.core.load_resource_description(model_description) - # Load inputs and outputs. image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx")) embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx")) - box_prompts = np.load(input_paths["box_prompts"], dims=tuple("bic")) + box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic")) + point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic")) + point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic")) mask = np.load(result_paths["mask"]) - breakpoint() - - # Check with box prompt. - with bioimageio.core.create_prediction_pipeline(model) as pp: + # Check with box and point prompts. + with bioimageio.core.create_prediction_pipeline(model_description) as pp: prediction = pp.forward( image=image, embeddings=embeddings, box_prompts=box_prompts, + point_prompts=point_prompts, + point_labels=point_labels, ) - breakpoint() + assert len(prediction) == 3 + predicted_mask = prediction[0] + assert np.allclose(mask, predicted_mask) def export_sam_model( @@ -235,13 +251,49 @@ def export_sam_model( data=spec.IntervalOrRatioDataDescr(type="int64") ), - # TODO - # Third input: the point prompts (optional) + # Third input: the point prompt coordinates (optional) + spec.InputTensorDescr( + id=spec.TensorId("point_prompts"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + spec.IndexInputAxis( + id=spec.AxisId("point"), + size=spec.ARBITRARY_SIZE + ), + spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]), + ], + test_tensor=spec.FileDescr(source=input_paths["point_prompts"]), + data=spec.IntervalOrRatioDataDescr(type="int64") + ), + + # Fourth input: the point prompt labels (optional) + spec.InputTensorDescr( + id=spec.TensorId("point_labels"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + spec.IndexInputAxis( + id=spec.AxisId("point"), + size=spec.ARBITRARY_SIZE + ), + ], + test_tensor=spec.FileDescr(source=input_paths["point_labels"]), + data=spec.IntervalOrRatioDataDescr(type="int64") + ), # TODO - # Fourth input: the mask prompts (optional) + # Fifth input: the mask prompts (optional) - # Fifth input: the image embeddings (optional) + # Sixth input: the image embeddings (optional) spec.InputTensorDescr( id=spec.TensorId("embeddings"), optional=True, diff --git a/micro_sam/bioimageio/predictor_adaptor.py b/micro_sam/bioimageio/predictor_adaptor.py index 13ec32835..509544f53 100644 --- a/micro_sam/bioimageio/predictor_adaptor.py +++ b/micro_sam/bioimageio/predictor_adaptor.py @@ -15,7 +15,6 @@ from segment_anything import sam_model_registry -# TODO we need to accept and return an additional tensor for the image sizes to support embeddings class PredictorAdaptor(nn.Module): """Wrapper around the SamPredictor. 
@@ -40,14 +39,19 @@ def forward( self, image: torch.Tensor, box_prompts: Optional[torch.Tensor] = None, - # TODO add point and mask prompts + point_prompts: Optional[torch.Tensor] = None, + point_labels: Optional[torch.Tensor] = None, + mask_prompts: Optional[torch.Tensor] = None, embeddings: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: image: torch inputs of dimensions B x C x H x W - box_prompts: box prompts of dimensions B x OBJECTS x 4 + box_prompts: box coordinates of dimensions B x OBJECTS x 4 + point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2 + point_labels: point labels of dimension B x OBJECTS x POINTS + mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256 embeddings: precomputed image embeddings B x 256 x 64 x 64 Returns: @@ -67,17 +71,38 @@ def forward( self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:] self.sam.is_image_set = True - # We don't have image embeddings set and they were not apassed + # We don't have image embeddings set and they were not passed. elif not self.sam.is_image_set: image = self.sam.transform.apply_image_torch(image) self.sam.set_torch_image(image, original_image_size=image.numpy().shape[2:]) - - boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=image.numpy().shape[2:]) - + self.sam.orig_h, self.sam.orig_w = self.sam.original_size + self.sam.input_h, self.sam.input_w = self.sam.input_size + + # Ensure input size and original size are set. + self.sam.input_size = (self.sam.input_h, self.sam.input_w) + self.sam.original_size = (self.sam.orig_h, self.sam.orig_w) + + if box_prompts is None: + boxes = None + else: + boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size) + + if point_prompts is None: + point_coords = None + else: + assert point_labels is not None + point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0] + point_labels = point_labels[0] + + print() + print(point_coords.shape) + print(point_labels.shape) + print(boxes.shape) masks, scores, _ = self.sam.predict_torch( - point_coords=None, - point_labels=None, + point_coords=point_coords, + point_labels=point_labels, boxes=boxes, + mask_input=mask_prompts, multimask_output=False ) From a170511520d6aa2b09c59d2db24111f40336e936 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 21 Mar 2024 09:05:58 +0100 Subject: [PATCH 20/21] Add all possible model inputs --- micro_sam/bioimageio/model_export.py | 72 ++++++++++++++++++++--- micro_sam/bioimageio/predictor_adaptor.py | 13 ++-- 2 files changed, 73 insertions(+), 12 deletions(-) diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 7ba1df866..65909a32a 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -17,6 +17,7 @@ from .. 
import util from ..prompt_generators import PointAndBoxPromptGenerator from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box +from ..prompt_based_segmentation import _compute_logits_from_mask from .predictor_adaptor import PredictorAdaptor DEFAULTS = { @@ -56,6 +57,14 @@ def _create_test_inputs_and_outputs( point_prompts = point_prompts.numpy()[None] point_labels = point_labels.numpy()[None] + # Generate logits from the two + mask_prompts = np.stack( + [ + _compute_logits_from_mask(labels == 1), + _compute_logits_from_mask(labels == 2), + ] + )[None] + predictor = PredictorAdaptor(model_type=model_type) predictor.load_state_dict(torch.load(checkpoint_path)) @@ -69,14 +78,17 @@ def _create_test_inputs_and_outputs( box_prompts=torch.from_numpy(box_prompts), point_prompts=torch.from_numpy(point_prompts), point_labels=torch.from_numpy(point_labels), + mask_prompts=torch.from_numpy(mask_prompts), ) box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy") point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy") point_label_path = os.path.join(tmp_dir, "point_labels.npy") + mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy") np.save(box_prompt_path, box_prompts) np.save(point_prompt_path, point_prompts) np.save(point_label_path, point_labels) + np.save(mask_prompt_path, mask_prompts) mask_path = os.path.join(tmp_dir, "mask.npy") score_path = os.path.join(tmp_dir, "scores.npy") @@ -90,6 +102,7 @@ def _create_test_inputs_and_outputs( "box_prompts": box_prompt_path, "point_prompts": point_prompt_path, "point_labels": point_label_path, + "mask_prompts": mask_prompt_path, } outputs = { "mask": mask_path, @@ -168,26 +181,56 @@ def _generate_covers(input_paths, result_paths, tmp_dir): def _check_model(model_description, input_paths, result_paths): - # Load inputs and outputs. + # Load inputs. image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx")) embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx")) box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic")) point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic")) point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic")) + mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx")) + + # Load outputs. mask = np.load(result_paths["mask"]) - # Check with box and point prompts. with bioimageio.core.create_prediction_pipeline(model_description) as pp: + + # Check with all prompts. We only check the result for this setting, + # because this was used to generate the test data. prediction = pp.forward( image=image, - embeddings=embeddings, box_prompts=box_prompts, point_prompts=point_prompts, point_labels=point_labels, + mask_prompts=mask_prompts, + embeddings=embeddings, ) - assert len(prediction) == 3 - predicted_mask = prediction[0] - assert np.allclose(mask, predicted_mask) + + assert len(prediction) == 3 + predicted_mask = prediction[0] + assert np.allclose(mask, predicted_mask) + + # FIXME this fails due to errors with optional inputs + # Check with partial prompts. + prompt_kwargs = [ + # With boxes. + {"box_prompts": box_prompts}, + # With point prompts. + {"point_prompts": point_prompts, "point_labels": point_labels}, + # With masks. + {"mask_prompts": mask_prompts}, + # With boxes and points. + {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels}, + # With boxes and masks. 
+ {"box_prompts": box_prompts, "mask_prompts": mask_prompts}, + # With points and masks. + {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels}, + ] + + for kwargs in prompt_kwargs: + prediction = pp.forward(image=image, embeddings=embeddings, **kwargs) + assert len(prediction) == 3 + predicted_mask = prediction[0] + assert predicted_mask.shape == mask.shape def export_sam_model( @@ -290,8 +333,23 @@ def export_sam_model( data=spec.IntervalOrRatioDataDescr(type="int64") ), - # TODO # Fifth input: the mask prompts (optional) + spec.InputTensorDescr( + id=spec.TensorId("mask_prompts"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + spec.ChannelAxis(channel_names=["channel"]), + spec.SpaceInputAxis(id=spec.AxisId("y"), size=256), + spec.SpaceInputAxis(id=spec.AxisId("x"), size=256), + ], + test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]), + data=spec.IntervalOrRatioDataDescr(type="float32") + ), # Sixth input: the image embeddings (optional) spec.InputTensorDescr( diff --git a/micro_sam/bioimageio/predictor_adaptor.py b/micro_sam/bioimageio/predictor_adaptor.py index 509544f53..c4a53e641 100644 --- a/micro_sam/bioimageio/predictor_adaptor.py +++ b/micro_sam/bioimageio/predictor_adaptor.py @@ -78,6 +78,8 @@ def forward( self.sam.orig_h, self.sam.orig_w = self.sam.original_size self.sam.input_h, self.sam.input_w = self.sam.input_size + assert self.sam.is_image_set, "The predictor has not yet been initialized." + # Ensure input size and original size are set. self.sam.input_size = (self.sam.input_h, self.sam.input_w) self.sam.original_size = (self.sam.orig_h, self.sam.orig_w) @@ -94,15 +96,16 @@ def forward( point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0] point_labels = point_labels[0] - print() - print(point_coords.shape) - print(point_labels.shape) - print(boxes.shape) + if mask_prompts is None: + mask_input = None + else: + mask_input = mask_prompts[0] + masks, scores, _ = self.sam.predict_torch( point_coords=point_coords, point_labels=point_labels, boxes=boxes, - mask_input=mask_prompts, + mask_input=mask_input, multimask_output=False ) From c181e29f0446f76e6968756ef240fca0d23f48ee Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 9 Apr 2024 15:51:21 +0200 Subject: [PATCH 21/21] Bioimageio updates WIP --- micro_sam/bioimageio/model_export.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 65909a32a..b2d77423c 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -12,6 +12,7 @@ import xarray from bioimageio.spec import save_bioimageio_package +from bioimageio.core.digest_spec import create_sample_for_model from .. import util @@ -29,7 +30,11 @@ "cite": [ spec.CiteEntry(text="Archit et al. 
Segment Anything for Microscopy", doi=spec.Doi("10.1101/2023.08.21.554208")), ], - "tags": ["segment-anything", "instance-segmentation"] + "tags": ["segment-anything", "instance-segmentation"], + # FIXME these are details for the uploader we should remove here + "uploader": spec.Uploader(email="constantin.pape@informatik.uni-goettinge.de"), + "id": "acclaimed-angelfish", + "id_emoji": "🐠", } @@ -40,7 +45,6 @@ def _create_test_inputs_and_outputs( checkpoint_path, tmp_dir, ): - # For now we just generate a single box prompt here, but we could also generate more input prompts. generator = PointAndBoxPromptGenerator( n_positive_points=1, @@ -196,21 +200,22 @@ def _check_model(model_description, input_paths, result_paths): # Check with all prompts. We only check the result for this setting, # because this was used to generate the test data. - prediction = pp.forward( + sample = create_sample_for_model( + model=model_description, image=image, box_prompts=box_prompts, point_prompts=point_prompts, point_labels=point_labels, mask_prompts=mask_prompts, embeddings=embeddings, - ) + ).as_single_block() + prediction = pp.predict_sample_block(sample) assert len(prediction) == 3 predicted_mask = prediction[0] assert np.allclose(mask, predicted_mask) - # FIXME this fails due to errors with optional inputs - # Check with partial prompts. + # Run the checks with partial prompts. prompt_kwargs = [ # With boxes. {"box_prompts": box_prompts}, @@ -227,7 +232,10 @@ def _check_model(model_description, input_paths, result_paths): ] for kwargs in prompt_kwargs: - prediction = pp.forward(image=image, embeddings=embeddings, **kwargs) + sample = create_sample_for_model( + model=model_description, image=image, embeddings=embeddings, **kwargs + ).as_single_block() + prediction = pp.predict_sample_block(sample) assert len(prediction) == 3 predicted_mask = prediction[0] assert predicted_mask.shape == mask.shape @@ -465,6 +473,9 @@ def export_sam_model( git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"), tags=kwargs.get("tags", DEFAULTS["tags"]), covers=covers, + uploader=kwargs.get("uploader", DEFAULTS["uploader"]), + id=kwargs.get("id", DEFAULTS["id"]), + id_emoji=kwargs.get("id_emoji", DEFAULTS["id_emoji"]), # TODO attach the decoder weights if given # Can be list of files??? # attachments=[spec.FileDescr(source=file_path) for file_path in attachment_files] @@ -474,6 +485,6 @@ def export_sam_model( # config= ) - _check_model(model_description, input_paths, result_paths) + # _check_model(model_description, input_paths, result_paths) save_bioimageio_package(model_description, output_path=output_path)