11import multiprocessing as mp
22import os
3+ import sys
34import warnings
45from concurrent import futures
56
1011import vigra
1112import torch
1213import z5py
14+ import zarr
1315import json
1416
1517from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper
1820from torch_em .util import load_model
1921from torch_em .util .prediction import predict_with_halo
2022from tqdm import tqdm
23+ from inspect import getsourcefile
24+
25+ sys .path .append (os .path .join (os .path .dirname (os .path .dirname (os .path .dirname (getsourcefile (lambda :0 )))), "scripts" , "prediction" ))
26+ import upload_to_s3
2127
2228"""
2329Prediction using distance U-Net.
@@ -43,7 +49,7 @@ def ndim(self):
4349 return self ._volume .ndim - 1
4450
4551
46- def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances = 1 , slurm_task_id = 0 , mean = None , std = None ):
52+ def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances = 1 , slurm_task_id = 0 , mean = None , std = None , s3 = None ):
4753 with warnings .catch_warnings ():
4854 warnings .simplefilter ("ignore" )
4955 if os .path .isdir (model_path ):
@@ -56,6 +62,9 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
5662
5763 if input_key is None :
5864 input_ = imageio .imread (input_path )
65+ elif s3 is not None :
66+ with zarr .open (input_path , mode = "r" ) as f :
67+ input_ = f [input_key ]
5968 else :
6069 input_ = open_file (input_path , "r" )[input_key ]
6170
@@ -138,7 +147,7 @@ def postprocess(x):
138147 return original_shape
139148
140149
141- def find_mask (input_path , input_key , output_folder ):
150+ def find_mask (input_path , input_key , output_folder , s3 = None ):
142151 mask_path = os .path .join (output_folder , "mask.zarr" )
143152 f = z5py .File (mask_path , "a" )
144153
@@ -149,6 +158,10 @@ def find_mask(input_path, input_key, output_folder):
149158 if input_key is None :
150159 raw = imageio .imread (input_path )
151160 chunks = (64 , 64 , 64 )
161+ elif s3 is not None :
162+ with zarr .open (input_path , mode = "r" ) as fin :
163+ raw = fin [input_key ]
164+ chunks = raw .chunks
152165 else :
153166 fin = open_file (input_path , "r" )
154167 raw = fin [input_key ]
@@ -243,7 +256,10 @@ def write_block(block_id):
243256 tp .map (write_block , range (blocking .numberOfBlocks ))
244257
245258
246- def calc_mean_and_std (input_path , input_key , output_folder ):
259+ def calc_mean_and_std (
260+ input_path , input_key , output_folder ,
261+ s3 = None ,
262+ ):
247263 """
248264 Calculate mean and standard deviation of full volume.
249265 Parameters are saved in 'mean_std.json' within the output folder.
@@ -254,6 +270,9 @@ def calc_mean_and_std(input_path, input_key, output_folder):
254270
255271 if input_key is None :
256272 input_ = imageio .imread (input_path )
273+ elif s3 is not None :
274+ with zarr .open (input_path , mode = "r" ) as f :
275+ input_ = f [input_key ]
257276 else :
258277 input_ = open_file (input_path , "r" )[input_key ]
259278
@@ -267,6 +286,7 @@ def calc_mean_and_std(input_path, input_key, output_folder):
267286 with open (json_file , "w" ) as f :
268287 json .dump (ddict , f )
269288
289+
270290def run_unet_prediction (
271291 input_path , input_key ,
272292 output_folder , model_path ,
@@ -288,32 +308,63 @@ def run_unet_prediction(
288308
289309def run_unet_prediction_preprocess_slurm (
290310 input_path , input_key , output_folder ,
311+ s3 = None , s3_bucket_name = None , s3_service_endpoint = None , s3_credentials = None ,
291312):
292313 """
293314 Pre-processing for the parallel prediction with U-Net models.
294315 Masks are stored in mask.zarr in the output folder.
295316 The mean and standard deviation are precomputed for later usage during prediction
296- and stored in a JSON file within the output folder as mean_std.json
317+ and stored in a JSON file within the output folder as mean_std.json.
297318 """
298- find_mask (input_path , input_key , output_folder )
299- calc_mean_and_std (input_path , input_key , output_folder )
319+ if s3 is not None :
320+ bucket_name , service_endpoint , credentials = upload_to_s3 .check_s3_credentials (s3_bucket_name , s3_service_endpoint , s3_credentials )
321+
322+ input_path , fs = upload_to_s3 .get_s3_path (input_path , bucket_name = bucket_name , service_endpoint = service_endpoint , credential_file = credentials )
323+
324+ if not os .path .isdir (os .path .join (output_folder , "mask.zarr" )):
325+ find_mask (input_path , input_key , output_folder , s3 = s3 )
326+
327+ calc_mean_and_std (input_path , input_key , output_folder , s3 = s3 )
328+
300329
301330def run_unet_prediction_slurm (
302331 input_path , input_key , output_folder , model_path ,
303332 scale = None ,
304333 block_shape = None , halo = None , prediction_instances = 1 ,
334+ s3 = None , s3_bucket_name = None , s3_service_endpoint = None , s3_credentials = None ,
305335):
336+ """
337+ Run prediction of distance U-Net for data stored locally or on an S3 bucket.
338+
339+ :param str input_path: File path to input data
340+ :param str input_key: Input key for data in ome.zarr format
341+ :param str output_folder: Output folder for prediction.zarr
342+ :param str model_path: File path to distance U-Net model
343+ :param float scale:
344+ :param tuple block_shape:
345+ :param tuple halo:
346+ :param int prediction_instances: Number of workers for parallel computation within slurm array
347+ :param bool s3: Flag for accessing data on S3 bucket
348+ :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported
349+ :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported
350+ :param str s3_credentials: Path to file containing S3 credentials
351+ """
306352 os .makedirs (output_folder , exist_ok = True )
307353 prediction_instances = int (prediction_instances )
308354 slurm_task_id = os .environ .get ("SLURM_ARRAY_TASK_ID" )
309355
356+ if s3 is not None :
357+ bucket_name , service_endpoint , credentials = upload_to_s3 .check_s3_credentials (s3_bucket_name , s3_service_endpoint , s3_credentials )
358+
359+ input_path , fs = upload_to_s3 .get_s3_path (input_path , bucket_name = bucket_name , service_endpoint = service_endpoint , credential_file = credentials )
360+
310361 if slurm_task_id is not None :
311362 slurm_task_id = int (slurm_task_id )
312363 else :
313364 raise ValueError ("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH." )
314365
315366 if not os .path .isdir (os .path .join (output_folder , "mask.zarr" )):
316- find_mask (input_path , input_key , output_folder )
367+ find_mask (input_path , input_key , output_folder , s3 = s3 )
317368
318369 # get pre-computed mean and standard deviation of full volume from JSON file
319370 if os .path .isfile (os .path .join (output_folder , "mean_std.json" )):
@@ -328,9 +379,10 @@ def run_unet_prediction_slurm(
328379 original_shape = prediction_impl (
329380 input_path , input_key , output_folder , model_path , scale , block_shape , halo ,
330381 prediction_instances = prediction_instances , slurm_task_id = slurm_task_id ,
331- mean = mean , std = std ,
382+ mean = mean , std = std , s3 = s3 ,
332383 )
333384
385+
334386# does NOT need GPU, FIXME: only run on CPU
335387def run_unet_segmentation_slurm (output_folder , min_size ):
336388 min_size = int (min_size )
0 commit comments