|
| 1 | +import napari |
| 2 | +from elf.io import open_file |
| 3 | +import h5py |
| 4 | +import os |
| 5 | +import torch |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +import micro_sam.sam_3d_wrapper as sam_3d |
| 9 | +import micro_sam.util as util |
| 10 | +# from micro_sam.segment_instances import ( |
| 11 | +# segment_instances_from_embeddings, |
| 12 | +# segment_instances_sam, |
| 13 | +# segment_instances_from_embeddings_3d, |
| 14 | +# ) |
| 15 | +from micro_sam import multi_dimensional_segmentation as mds |
| 16 | +from micro_sam.visualization import compute_pca |
| 17 | +INPUT_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5" |
| 18 | +# EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/embedding-mito-3d.zarr" |
| 19 | +EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/usr/nimlufre/" |
| 20 | +INPUT_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5" |
| 21 | +EMBEDDINGS_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/" |
| 22 | +INPUT_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/4007_cutout_1.h5" |
| 23 | +EMBEDDINGS_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/embedding-mito-3d.zarr" |
| 24 | +TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01" |
| 25 | +EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr" |
| 26 | + |
| 27 | +# def cell_segmentation_3d() -> None: |
| 28 | +# with open_file(TIMESERIES_PATH, mode="r") as f: |
| 29 | +# timeseries = f["*.tif"][:50] |
| 30 | + |
| 31 | +# predictor = util.get_sam_model() |
| 32 | +# image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH) |
| 33 | + |
| 34 | +# seg = segment_instances_from_embeddings_3d(predictor, image_embeddings) |
| 35 | + |
| 36 | +# v = napari.Viewer() |
| 37 | +# v.add_image(timeseries) |
| 38 | +# v.add_labels(seg) |
| 39 | +# napari.run() |
| 40 | + |
| 41 | + |
| 42 | +# def _get_dataset_and_reshape(path: str, key: str = "raw", shape: tuple = (32, 256, 256)) -> np.ndarray: |
| 43 | + |
| 44 | +# with h5py.File(path, "r") as f: |
| 45 | +# # Check if the key exists in the file |
| 46 | +# if key not in f: |
| 47 | +# raise KeyError(f"Dataset with key '{key}' not found in file '{path}'.") |
| 48 | + |
| 49 | +# # Load the dataset |
| 50 | +# dataset = f[key][...] |
| 51 | + |
| 52 | +# # Reshape the dataset |
| 53 | +# if dataset.shape != shape: |
| 54 | +# try: |
| 55 | +# # Attempt to reshape the dataset to the desired shape |
| 56 | +# dataset = dataset.reshape(shape) |
| 57 | +# except ValueError: |
| 58 | +# raise ValueError(f"Failed to reshape dataset with key '{key}' to shape {shape}.") |
| 59 | + |
| 60 | +# return dataset |
| 61 | +def get_dataset_cutout(path: str, key: str = "raw", shape: tuple = (32, 256, 256), |
| 62 | + start_index: tuple = (0, 0, 0)) -> np.ndarray: |
| 63 | + """ |
| 64 | + Loads a cutout from a dataset in an HDF5 file. |
| 65 | +
|
| 66 | + Args: |
| 67 | + path (str): Path to the HDF5 file. |
| 68 | + key (str, optional): Key of the dataset to load. Defaults to "raw". |
| 69 | + shape (tuple, optional): Desired shape of the cutout. Defaults to (32, 256, 256). |
| 70 | + start_index (tuple, optional): Starting index for the cutout within the dataset. |
| 71 | + Defaults to None, which selects a random starting point within valid bounds. |
| 72 | +
|
| 73 | + Returns: |
| 74 | + np.ndarray: The loaded cutout of the dataset with the specified shape. |
| 75 | +
|
| 76 | + Raises: |
| 77 | + KeyError: If the specified key is not found in the HDF5 file. |
| 78 | + ValueError: If the cutout shape exceeds the dataset dimensions or the starting index is invalid. |
| 79 | + """ |
| 80 | + |
| 81 | + with h5py.File(path, "r") as f: |
| 82 | + |
| 83 | + dataset = f[key] |
| 84 | + dataset_shape = dataset.shape |
| 85 | + print("original data shape", dataset_shape) |
| 86 | + |
| 87 | + # Validate cutout shape |
| 88 | + if any(s > d for s, d in zip(shape, dataset_shape)): |
| 89 | + raise ValueError(f"Cutout shape {shape} exceeds dataset dimensions {dataset_shape}.") |
| 90 | + |
| 91 | + # Generate random starting index if not provided |
| 92 | + if start_index is None: |
| 93 | + start_index = tuple(np.random.randint(0, dim - s + 1, size=len(shape)) for dim, s in zip(dataset_shape, shape)) |
| 94 | + |
| 95 | + # Calculate end index |
| 96 | + end_index = tuple(min(i + s, dim) for i, s, dim in zip(start_index, shape, dataset_shape)) |
| 97 | + |
| 98 | + # Load the cutout |
| 99 | + cutout = dataset[start_index[0]:end_index[0], |
| 100 | + start_index[1]:end_index[1], |
| 101 | + start_index[2]:end_index[2]] |
| 102 | + print("cutout data shape", cutout.shape) |
| 103 | + |
| 104 | + return cutout |
| 105 | + |
| 106 | + |
| 107 | +def mito_segmentation_3d() -> None: |
| 108 | + patch_shape = (32, 256, 256) |
| 109 | + start_index = (10, 32, 64) |
| 110 | + data_slice = get_dataset_cutout(INPUT_PATH_LOCAL, shape=patch_shape) #start_index=start_index |
| 111 | + |
| 112 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 113 | + model_type = "vit_b" |
| 114 | + predictor, sam = util.get_sam_model(return_sam=True, model_type=model_type, device=device) |
| 115 | + |
| 116 | + d_size = 3 |
| 117 | + predictor3d = sam_3d.Predictor3D(sam, d_size) |
| 118 | + print(predictor3d) |
| 119 | + #breakpoint() |
| 120 | + predictor3d.model.forward(torch.from_numpy(data_slice), multimask_output=False, image_size=patch_shape) |
| 121 | + # output = predictor3d.model([data_slice], multimask_output=False)#image_size=patch_shape |
| 122 | + |
| 123 | + # predictor3d._hash = util.models().registry[model_type] |
| 124 | + |
| 125 | + # predictor3d.model_name = model_type |
| 126 | + |
| 127 | + # image_embeddings = util.precompute_image_embeddings(predictor3d, volume, EMBEDDINGS_PATH_CLUSTER) |
| 128 | + # seg = util.segment_instances_from_embeddings_3d(predictor3d, image_embeddings) |
| 129 | + |
| 130 | + # prediction_filename = os.path.join(EMBEDDINGS_PATH_CLUSTER, f"prediction_{INPUT_PATH_CLUSTER}.h5") |
| 131 | + # with h5py.File(prediction_filename, "w") as prediction_file: |
| 132 | + # prediction_file.create_dataset("prediction", data=seg) |
| 133 | + |
| 134 | + # visualize |
| 135 | + # v = napari.Viewer() |
| 136 | + # v.add_image(volume) |
| 137 | + # v.add_labels(seg) |
| 138 | + # v.add_labels(seg_sam) |
| 139 | + # napari.run() |
| 140 | + |
| 141 | + |
| 142 | + |
| 143 | +def main(): |
| 144 | + # automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py') |
| 145 | + # nucleus_segmentation(use_mws=True) |
| 146 | + mito_segmentation_3d() |
| 147 | + |
| 148 | + # automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py') |
| 149 | + # cell_segmentation(use_mws=True) |
| 150 | + # cell_segmentation_3d() |
| 151 | + |
| 152 | + |
| 153 | +if __name__ == "__main__": |
| 154 | + main() |
0 commit comments