| 
 | 1 | +import argparse  | 
 | 2 | +import os  | 
 | 3 | +import subprocess  | 
 | 4 | + | 
 | 5 | +import cryoet_data_portal as cdp  | 
 | 6 | +import numpy as np  | 
 | 7 | +import zarr  | 
 | 8 | + | 
 | 9 | +from ome_zarr.io import parse_url  | 
 | 10 | +from ome_zarr.writer import write_image  | 
 | 11 | +from synapse_net.file_utils import read_data_from_cryo_et_portal_run  | 
 | 12 | +from synapse_net.inference.vesicles import segment_vesicles  | 
 | 13 | +from tqdm import tqdm  | 
 | 14 | + | 
# Root folder under which the per-deposition/per-dataset segmentation
# results are written (see run_prediction). The empty value below can be
# used for local test runs in the working directory.
# OUTPUT_ROOT = ""
OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/synaptic-reconstruction/portal"
 | 17 | + | 
 | 18 | + | 
 | 19 | +def get_tomograms(deposition_id, processing_type, number_of_tomograms=None):  | 
 | 20 | +    client = cdp.Client()  | 
 | 21 | +    tomograms = cdp.Tomogram.find(  | 
 | 22 | +        client, [cdp.Tomogram.deposition_id == deposition_id, cdp.Tomogram.processing == processing_type]  | 
 | 23 | +    )  | 
 | 24 | +    if number_of_tomograms is not None:  | 
 | 25 | +        tomograms = tomograms[:number_of_tomograms]  | 
 | 26 | +    return tomograms  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +def write_ome_zarr(output_file, segmentation, voxel_size, unit="nanometer"):  | 
 | 30 | +    store = parse_url(output_file, mode="w").store  | 
 | 31 | +    root = zarr.group(store=store)  | 
 | 32 | + | 
 | 33 | +    scale = list(voxel_size.values())  | 
 | 34 | +    trafo = [  | 
 | 35 | +        [{"scale": scale, "type": "scale"}]  | 
 | 36 | +    ]  | 
 | 37 | +    axes = [  | 
 | 38 | +        {"name": "z", "type": "space", "unit": unit},  | 
 | 39 | +        {"name": "y", "type": "space", "unit": unit},  | 
 | 40 | +        {"name": "x", "type": "space", "unit": unit},  | 
 | 41 | +    ]  | 
 | 42 | +    write_image(segmentation, root, axes=axes, coordinate_transformations=trafo, scaler=None)  | 
 | 43 | + | 
 | 44 | + | 
 | 45 | +def run_prediction(tomogram, deposition_id, processing_type):  | 
 | 46 | +    output_folder = os.path.join(OUTPUT_ROOT, f"upload_CZCDP-{deposition_id}", str(tomogram.run.dataset_id))  | 
 | 47 | +    os.makedirs(output_folder, exist_ok=True)  | 
 | 48 | + | 
 | 49 | +    output_file = os.path.join(output_folder, f"{tomogram.run.name}.zarr")  | 
 | 50 | +    # We don't need to do anything if this file is already processed.  | 
 | 51 | +    if os.path.exists(output_file):  | 
 | 52 | +        return  | 
 | 53 | + | 
 | 54 | +    # Read tomogram data on the fly.  | 
 | 55 | +    data, voxel_size = read_data_from_cryo_et_portal_run(  | 
 | 56 | +        tomogram.run_id, processing_type=processing_type  | 
 | 57 | +    )  | 
 | 58 | + | 
 | 59 | +    # Segment vesicles.  | 
 | 60 | +    model_path = "/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-portal-v3"  | 
 | 61 | +    scale = (1.0 / 2.7,) * 3  | 
 | 62 | +    segmentation = segment_vesicles(data, model_path=model_path, scale=scale)  | 
 | 63 | + | 
 | 64 | +    # Save the segmentation.  | 
 | 65 | +    write_ome_zarr(output_file, segmentation, voxel_size)  | 
 | 66 | + | 
 | 67 | + | 
 | 68 | +# TODO download on lower scale  | 
 | 69 | +def check_result(tomogram, deposition_id, processing_type):  | 
 | 70 | +    import napari  | 
 | 71 | + | 
 | 72 | +    # Read tomogram data on the fly.  | 
 | 73 | +    print("Download data ...")  | 
 | 74 | +    data, voxel_size = read_data_from_cryo_et_portal_run(  | 
 | 75 | +        tomogram.run_id, processing_type=processing_type  | 
 | 76 | +    )  | 
 | 77 | + | 
 | 78 | +    # Read the output file if it exists.  | 
 | 79 | +    output_folder = os.path.join(f"upload_CZCDP-{deposition_id}", str(tomogram.run.dataset_id))  | 
 | 80 | +    output_file = os.path.join(output_folder, f"{tomogram.run.name}.zarr")  | 
 | 81 | +    if os.path.exists(output_file):  | 
 | 82 | +        with zarr.open(output_file, "r") as f:  | 
 | 83 | +            segmentation = f["0"][:]  | 
 | 84 | +    else:  | 
 | 85 | +        segmentation = None  | 
 | 86 | + | 
 | 87 | +    v = napari.Viewer()  | 
 | 88 | +    v.add_image(data)  | 
 | 89 | +    if segmentation is not None:  | 
 | 90 | +        v.add_labels(segmentation)  | 
 | 91 | +    napari.run()  | 
 | 92 | + | 
 | 93 | + | 
 | 94 | +def _get_task_tomograms(tomograms, slurm_tasks, task_id):  | 
 | 95 | +    # TODO we could also filter already done tomos.  | 
 | 96 | +    inputs_to_tasks = np.array_split(tomograms, slurm_tasks)  | 
 | 97 | +    assert len(inputs_to_tasks) == slurm_tasks  | 
 | 98 | +    return inputs_to_tasks[task_id]  | 
 | 99 | + | 
 | 100 | + | 
 | 101 | +def process_slurm(args, tomograms, deposition_id, processing_type):  | 
 | 102 | +    assert not args.check  | 
 | 103 | +    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")  | 
 | 104 | + | 
 | 105 | +    if task_id is None:  # We are not in the slurm task and submit the job.  | 
 | 106 | +        # Assemble the command for submitting a slurm array job.  | 
 | 107 | +        script_path = "process_tomograms_on_the_fly.sbatch"  | 
 | 108 | +        cmd = ["sbatch", "-a", f"0-{args.slurm_tasks-1}", script_path, "-s", str(args.slurm_tasks)]  | 
 | 109 | +        print("Submitting to slurm:")  | 
 | 110 | +        print(cmd)  | 
 | 111 | +        subprocess.run(cmd)  | 
 | 112 | +    else:  # We are in the task.  | 
 | 113 | +        task_id = int(task_id)  | 
 | 114 | +        this_tomograms = _get_task_tomograms(tomograms, args.slurm_tasks, task_id)  | 
 | 115 | +        for tomogram in tqdm(this_tomograms, desc="Run prediction for tomograms on-the-fly"):  | 
 | 116 | +            run_prediction(tomogram, deposition_id, processing_type)  | 
 | 117 | + | 
 | 118 | + | 
 | 119 | +def process_local(args, tomograms, deposition_id, processing_type):  | 
 | 120 | +    # Process each tomogram.  | 
 | 121 | +    print("Start processing", len(tomograms), "tomograms")  | 
 | 122 | +    for tomogram in tqdm(tomograms, desc="Run prediction for tomograms on-the-fly"):  | 
 | 123 | +        if args.check:  | 
 | 124 | +            check_result(tomogram, deposition_id, processing_type)  | 
 | 125 | +        else:  | 
 | 126 | +            run_prediction(tomogram, deposition_id, processing_type)  | 
 | 127 | + | 
 | 128 | + | 
 | 129 | +def main():  | 
 | 130 | +    parser = argparse.ArgumentParser()  | 
 | 131 | +    # Whether to check the result with napari instead of running the prediction.  | 
 | 132 | +    parser.add_argument("-c", "--check", action="store_true")  | 
 | 133 | +    parser.add_argument("-n", "--number_of_tomograms", type=int, default=None)  | 
 | 134 | +    parser.add_argument("-s", "--slurm_tasks", type=int, default=None)  | 
 | 135 | +    args = parser.parse_args()  | 
 | 136 | + | 
 | 137 | +    deposition_id = 10313  | 
 | 138 | +    processing_type = "denoised"  | 
 | 139 | + | 
 | 140 | +    # Get all the (processed) tomogram ids in the deposition.  | 
 | 141 | +    tomograms = get_tomograms(deposition_id, processing_type, args.number_of_tomograms)  | 
 | 142 | + | 
 | 143 | +    if args.slurm_tasks is None:  | 
 | 144 | +        process_local(args, tomograms, deposition_id, processing_type)  | 
 | 145 | +    else:  | 
 | 146 | +        process_slurm(args, tomograms, deposition_id, processing_type)  | 
 | 147 | + | 
 | 148 | + | 
 | 149 | +if __name__ == "__main__":  | 
 | 150 | +    main()  | 