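"""Development and test script for the micro_sam object classifier.

Each function exercises the object classifier on a different example dataset
(LIVECell, whole-slide, Lucchi 3D EM, a tiled 3D stack, and histopathology).
Select which test to run by (un)commenting the corresponding call in `main`.
The precomputed segmentations and annotations are expected under ./clf-test-data.
"""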
import os

import imageio.v3 as imageio
import numpy as np

from micro_sam.util import get_cache_directory
from micro_sam.sample_data import fetch_livecell_example_data, fetch_wholeslide_example_data, fetch_3d_example_data

from elf.io import open_file


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def livecell_annotator():
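    """Run the object classifier annotator on the LIVECell example image,
    using a previously computed vit_b_lm segmentation.
    """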
    from micro_sam.sam_annotator.object_classifier import object_classifier

    example_data = fetch_livecell_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")
    model_type = "vit_b_lm"

    # This is the vit_b_lm segmentation.
    segmentation = imageio.imread("./clf-test-data/livecell-test-seg.tif")

    object_classifier(image, segmentation, embedding_path=embedding_path, model_type=model_type)


def wholeslide_annotator():
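    """Run the object classifier annotator on the tiled whole-slide example image,
    using a previously computed vit_b_lm segmentation.
    """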
    from micro_sam.sam_annotator.object_classifier import object_classifier

    example_data = fetch_wholeslide_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")
    model_type = "vit_b_lm"

    segmentation = imageio.imread("./clf-test-data/whole-slide-seg.tif")
    object_classifier(
        image, segmentation, embedding_path=embedding_path, model_type=model_type,
        tile_shape=(1024, 1024), halo=(256, 256),
    )


def lucchi_annotator():
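    """Run the object classifier annotator on the 3D Lucchi EM example data,
    using a previously computed vit_b_em_organelles segmentation.
    """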
    from micro_sam.sam_annotator.object_classifier import object_classifier

    example_data = fetch_3d_example_data(DATA_CACHE)
    with open_file(example_data) as f:
        raw = f["*.png"][:]

    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_b_em_organelles.zarr")

    # The embeddings above were computed with vit_b_em_organelles, so use the same model here.
    model_type = "vit_b_em_organelles"
    segmentation = imageio.imread("./clf-test-data/lucchi-test-segmentation.tif")

    object_classifier(raw, segmentation, embedding_path=embedding_path, model_type=model_type)


def tiled_3d_annotator():
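    """Run the object classifier annotator on a small 3D stack with tiled embeddings,
    created by `create_3d_data_with_tiling` below.
    """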
    from micro_sam.sam_annotator.object_classifier import object_classifier
    from skimage.data import cells3d

    data = cells3d()[30:34, 1]
    embed_path = "./clf-test-data/embeds-3d-tiled.zarr"

    model_type = "vit_b_lm"
    segmentation = imageio.imread("./clf-test-data/tiled-3d-segmentation.tif")

    object_classifier(
        data, segmentation, embedding_path=embed_path, model_type=model_type,
        tile_shape=(128, 128), halo=(32, 32)
    )


def _get_livecell_data():
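    """Load the LIVECell image, segmentation, and test annotations for `annotator_devel`."""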
    example_data = fetch_livecell_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")

    # This is the vit_b_lm segmentation and a test annotation.
    segmentation = imageio.imread("./clf-test-data/livecell-test-seg.tif")
    annotations = imageio.imread("./clf-test-data/livecell-test-annotations.tif")

    model_type = "vit_b_lm"

    return image, segmentation, annotations, model_type, embedding_path, None, None


def _get_wholeslide_data():
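    """Load the whole-slide image, segmentation, and test annotations for `annotator_devel`."""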
    example_data = fetch_wholeslide_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")

    # This is the vit_b_lm segmentation and a test annotation.
    segmentation = imageio.imread("./clf-test-data/whole-slide-seg.tif")
    annotations = imageio.imread("./clf-test-data/wholeslide-annotations.tif")

    model_type = "vit_b_lm"

    return image, segmentation, annotations, model_type, embedding_path, (1024, 1024), (256, 256)


def _get_lucchi_data():
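    """Load the Lucchi EM volume, segmentation, and test annotations for `annotator_devel`."""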
    example_data = fetch_3d_example_data(DATA_CACHE)
    with open_file(example_data) as f:
        raw = f["*.png"][:]

    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_b_em_organelles.zarr")

    segmentation = imageio.imread("./clf-test-data/lucchi-test-segmentation.tif")
    annotations = imageio.imread("./clf-test-data/lucchi-annotations.tif")

    model_type = "vit_b_em_organelles"

    return raw, segmentation, annotations, model_type, embedding_path, None, None


def _get_3d_tiled_data():
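    """Load the tiled 3D example data, segmentation, and test annotations for `annotator_devel`."""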
    from skimage.data import cells3d

    data = cells3d()[30:34, 1]
    embed_path = "./clf-test-data/embeds-3d-tiled.zarr"
    model_type = "vit_b_lm"

    segmentation = imageio.imread("./clf-test-data/tiled-3d-segmentation.tif")
    annotations = imageio.imread("./clf-test-data/tiled-3d-annotations.tif")

    return data, segmentation, annotations, model_type, embed_path, (128, 128), (32, 32)


def annotator_devel():
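    """Exercise the object classification pipeline step by step, without the annotator GUI:
    embed the image, extract per-object features, train a random forest on the annotated
    objects, predict for all objects, and inspect the result in napari.
    """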
    from micro_sam import object_classification as core_clf
    from micro_sam.sam_annotator import object_classifier as clf
    from micro_sam.util import precompute_image_embeddings, get_sam_model

    # image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_livecell_data()
    # image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_wholeslide_data()
    # image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_lucchi_data()
    image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_3d_tiled_data()

    # 1. Get the SAM model.
    predictor = get_sam_model(model_type)
    # 2. Precompute the image embeddings.
    image_embeddings = precompute_image_embeddings(
        predictor, image, save_path=embedding_path, tile_shape=tile_shape, halo=halo
    )
    # 3. Get the segmentation ids and the extracted features for the segmented objects.
    seg_ids, features = core_clf.compute_object_features(image_embeddings, segmentation)
    # 4. Accumulate per-object labels from the annotations for training the RF.
    labels = clf._accumulate_labels(segmentation, annotations)
    # 5. Train the RF model.
    rf = clf._train_rf(features, labels, n_estimators=200, max_depth=10)
    # 6. Run prediction with the trained RF.
    object_prediction = rf.predict(features)
    # 7. Map the predictions back to the instance segmentation.
    prediction = core_clf.project_prediction_to_segmentation(segmentation, object_prediction, seg_ids)

    import napari
    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(annotations)
    v.add_labels(prediction)
    napari.run()


def create_3d_data_with_tiling():
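    """Create the test data for `tiled_3d_annotator`: segment a small 3D stack with tiled
    embeddings and open it in napari, with an empty labels layer for drawing annotations.
    """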
    from skimage.data import cells3d
    from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", is_tiled=True)
    data = cells3d()[30:34, 1]

    embed_path = "./clf-test-data/embeds-3d-tiled.zarr"
    seg = automatic_instance_segmentation(
        predictor, segmenter, data, embedding_path=embed_path, ndim=3, tile_shape=(128, 128), halo=(32, 32)
    )

    import napari
    v = napari.Viewer()
    v.add_image(data)
    v.add_labels(seg)
    # Empty labels layer for drawing the annotations.
    v.add_labels(np.zeros_like(seg))
    napari.run()


def histopathology_annotator():
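    """Run the object classifier on a series of histopathology images:
    segment the first ten LyNSeC images (or load cached segmentations) and
    start the image-series object classifier on them.
    """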
    from torch_em.data.datasets.histopathology.lynsec import get_lynsec_paths
    from micro_sam.sam_annotator import object_classifier as clf
    from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_histopathology")

    image_paths, _ = get_lynsec_paths(path="./clf-test-data/nuclick", choice="ihc", download=True)
    image_paths = image_paths[:10]

    images, segmentations = [], []
    embedding_paths = []

    for i, image_path in enumerate(image_paths):
        image = imageio.imread(image_path)
        embedding_path = f"./clf-test-data/embeddings_nuclick_{i}.zarr"
        seg_path = f"./clf-test-data/seg-nuclick_{i}.tif"

        if os.path.exists(seg_path):
            segmentation = imageio.imread(seg_path)
        else:
            segmentation = automatic_instance_segmentation(
                predictor, segmenter, embedding_path=embedding_path, input_path=image, ndim=2,
            )
            imageio.imwrite(seg_path, segmentation, compression="zlib")

        images.append(image)
        segmentations.append(segmentation)
        embedding_paths.append(embedding_path)

    clf.image_series_object_classifier(
        images, segmentations, output_folder="./clf-test-data/histo-results",
        embedding_paths=embedding_paths, model_type="vit_b_histopathology", ndim=2,
    )


def batch_prediction():
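    """Apply the random forest trained by `histopathology_annotator` to new images
    and inspect the predictions in napari.
    """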
    import napari
    from torch_em.data.datasets.histopathology.lynsec import get_lynsec_paths
    from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
    from micro_sam.object_classification import run_prediction_with_object_classifier
    from tqdm import tqdm

    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_histopathology")

    image_paths, _ = get_lynsec_paths(path="./clf-test-data/nuclick", choice="ihc", download=True)
    # Test the batch prediction on the next two images.
    image_paths = image_paths[10:12]

    # Prepare the images and segmentations.
    images, segmentations = [], []
    for image_path in tqdm(image_paths, desc="Segment images"):
        image = imageio.imread(image_path)
        segmentation = automatic_instance_segmentation(predictor, segmenter, input_path=image, ndim=2, verbose=False)
        images.append(image)
        segmentations.append(segmentation)

    rf_path = "clf-test-data/histo-results/rf.joblib"
    print("Start object classification.")
    predictions = run_prediction_with_object_classifier(images, segmentations, predictor, rf_path, ndim=2)

    for im, seg, pred in zip(images, segmentations, predictions):
        v = napari.Viewer()
        v.add_image(im)
        v.add_labels(seg)
        v.add_labels(pred)
        napari.run()


def main():
    # create_3d_data_with_tiling()

    # livecell_annotator()
    # wholeslide_annotator()
    # lucchi_annotator()
    # tiled_3d_annotator()
    histopathology_annotator()
    # batch_prediction()

    # annotator_devel()


if __name__ == "__main__":
    main()