
Commit 6bf43e3

Improved S3 functions, fixed issues
1 parent 4394134 commit 6bf43e3

File tree

6 files changed: +66 -66 lines changed

flamingo_tools/s3_utils.py

Lines changed: 37 additions & 33 deletions
@@ -3,49 +3,74 @@
 import s3fs
 import zarr
 
-from tqdm import tqdm
+"""
+This script contains utility functions for processing data located on an S3 storage system.
+The upload of data to the storage system should be performed with 'rclone'.
+"""
 
-# Using incucyte s3 as a temporary measure.
-MOBIE_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/moser/lightsheet/mobie"
+# Dedicated bucket for the cochlea lightsheet project.
+MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
 SERVICE_ENDPOINT = "https://s3.gwdg.de/"
-BUCKET_NAME = "incucyte-general/lightsheet"
+BUCKET_NAME = "cochlea-lightsheet"
+
+DEFAULT_CREDENTIALS = os.path.expanduser("~/.aws/credentials")
 
 # For MoBIE:
 # https://s3.gwdg.de/incucyte-general/lightsheet
 
-def check_s3_credentials(bucket_name, service_endpoint, credentials):
+def check_s3_credentials(bucket_name, service_endpoint, credential_file):
     """
     Check if S3 parameters and credentials were set either as a function input or were exported as environment variables.
     """
     if bucket_name is None:
         bucket_name = os.getenv('BUCKET_NAME')
     if bucket_name is None:
-        raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name <bucket_name>\nexport BUCKET_NAME=<bucket_name>")
+        if "BUCKET_NAME" in globals():
+            bucket_name = BUCKET_NAME
+        else:
+            raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name <bucket_name>\nexport BUCKET_NAME=<bucket_name>")
 
     if service_endpoint is None:
         service_endpoint = os.getenv('SERVICE_ENDPOINT')
     if service_endpoint is None:
-        raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint <endpoint>\nexport SERVICE_ENDPOINT=<endpoint>")
+        if "SERVICE_ENDPOINT" in globals():
+            service_endpoint = SERVICE_ENDPOINT
+        else:
+            raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint <endpoint>\nexport SERVICE_ENDPOINT=<endpoint>")
 
-    if credentials is None:
+    if credential_file is None:
         access_key = os.getenv('AWS_ACCESS_KEY_ID')
         secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+
+        # Check for default credentials if no credential_file is provided.
         if access_key is None:
-            raise ValueError("Either provide a credential file as an optional argument or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=<access_key>")
+            if os.path.isfile(DEFAULT_CREDENTIALS):
+                access_key, _ = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS)
+            else:
+                raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=<access_key>")
         if secret_key is None:
-            raise ValueError("Either provide a credential file as an optional argument or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=<secret_key>")
+            # Check for default credentials.
+            if os.path.isfile(DEFAULT_CREDENTIALS):
+                _, secret_key = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS)
+            else:
+                raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=<secret_key>")
 
-    return bucket_name, service_endpoint, credentials
+    else:
+        # Check the validity of the credential file.
+        _, _ = read_s3_credentials(credential_file=credential_file)
 
+    return bucket_name, service_endpoint, credential_file
 
 def get_s3_path(
     input_path,
-    bucket_name, service_endpoint,
+    bucket_name=None, service_endpoint=None,
     credential_file=None,
 ):
     """
     Get S3 path for a file or folder and file system based on S3 parameters and credentials.
     """
+    bucket_name, service_endpoint, credential_file = check_s3_credentials(bucket_name, service_endpoint, credential_file)
+
     fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file)
 
     zarr_path = f"{bucket_name}/{input_path}"
@@ -84,24 +109,3 @@ def create_s3_target(url, anon=False, credential_file=None):
     else:
         fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs)
     return fs
-
-
-def upload_data():
-    target = create_s3_target(
-        SERVICE_ENDPOINT,
-        credential_file="./credentials.incucyte"
-    )
-    to_upload = []
-    for root, dirs, files in os.walk(MOBIE_FOLDER):
-        dirs.sort()
-        for ff in files:
-            if ff.endswith(".xml"):
-                to_upload.append(os.path.join(root, ff))
-
-    print("Uploading", len(to_upload), "files to")
-
-    for path in tqdm(to_upload):
-        rel_path = os.path.relpath(path, MOBIE_FOLDER)
-        target.put(
-            path, os.path.join(BUCKET_NAME, rel_path)
-        )

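With these changes, get_s3_path resolves its S3 parameters on its own: explicit arguments win, then the BUCKET_NAME, SERVICE_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables, then ~/.aws/credentials and the module-level defaults. A minimal usage sketch, assuming the fallback chain above; the object path is a hypothetical example:

    import zarr

    from flamingo_tools import s3_utils

    # No bucket, endpoint, or credential file passed: the new fallback
    # chain in check_s3_credentials fills all three in.
    s3_path, fs = s3_utils.get_s3_path("images/example.ome.zarr")

    with zarr.open(s3_path, mode="r") as f:
        print(list(f.keys()))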
flamingo_tools/segmentation/unet_prediction.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
import z5py
1414
import zarr
15+
import tifffile
1516
import json
1617

1718
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
@@ -59,15 +60,18 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
5960
image_mask = z5py.File(mask_path, "r")["mask"]
6061

6162
if input_key is None:
62-
input_ = imageio.imread(input_path)
63-
chunks = (64, 64, 64)
64-
elif s3 is not None:
63+
try:
64+
input_ = tifffile.memmap(input_path, mode="r")
65+
except ValueError:
66+
print(f"Could not memmap the data from {input_path}. Fall back to load it into memory.")
67+
input_ = imageio.imread(input_path)
68+
elif isinstance(input_path, str):
69+
input_ = open_file(input_path, "r")[input_key]
70+
else:
6571
with zarr.open(input_path, mode="r") as f:
6672
input_ = f[input_key]
67-
chunks = input_.chunks()
68-
else:
69-
input_ = open_file(input_path, "r")[input_key]
70-
chunks = (64, 64, 64)
73+
74+
chunks = getattr(input_, "chunks", (64,64,64))
7175

7276
if scale is None or scale == 1:
7377
original_shape = None
@@ -157,16 +161,19 @@ def find_mask(input_path, input_key, output_folder, s3=None):
157161
return
158162

159163
if input_key is None:
160-
raw = imageio.imread(input_path)
161-
chunks = (64, 64, 64)
162-
elif s3 is not None:
163-
with zarr.open(input_path, mode="r") as fin:
164-
raw = fin[input_key]
165-
chunks = raw.chunks
166-
else:
164+
try:
165+
raw = tifffile.memmap(input_path, mode="r")
166+
except ValueError:
167+
print(f"Could not memmap the data from {input_path}. Fall back to load it into memory.")
168+
raw = imageio.imread(input_path)
169+
elif isinstance(input_path, str):
167170
fin = open_file(input_path, "r")
168171
raw = fin[input_key]
169-
chunks = (64, 64, 64)
172+
else:
173+
with zarr.open(input_path, mode="r") as fin:
174+
raw = fin[input_key]
175+
176+
chunks = getattr(raw, "chunks", (64,64,64))
170177

171178
block_shape = tuple(2 * ch for ch in chunks)
172179
blocking = nt.blocking([0, 0, 0], raw.shape, block_shape)
@@ -318,9 +325,7 @@ def run_unet_prediction_preprocess_slurm(
318325
and stored in a JSON file within the output folder as mean_std.json.
319326
"""
320327
if s3 is not None:
321-
bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
322-
323-
input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
328+
input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
324329

325330
if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
326331
find_mask(input_path, input_key, output_folder, s3=s3)
@@ -355,9 +360,7 @@ def run_unet_prediction_slurm(
355360
slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
356361

357362
if s3 is not None:
358-
bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
359-
360-
input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
363+
input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
361364

362365
if slurm_task_id is not None:
363366
slurm_task_id = int(slurm_task_id)

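The new loading logic is the same in prediction_impl and find_mask: try to memory-map a plain TIFF, fall back to an eager read, and let getattr supply default chunks for array types that have none. A standalone sketch of the pattern; the helper name load_volume, the file path, and the imageio import style are assumptions, not part of the commit:

    import imageio.v3 as imageio
    import tifffile

    def load_volume(input_path):
        """Open a TIFF volume lazily if possible.

        tifffile.memmap raises a ValueError for files that cannot be
        memory-mapped (e.g. compressed TIFFs); only then read eagerly.
        """
        try:
            return tifffile.memmap(input_path, mode="r")
        except ValueError:
            print(f"Could not memmap the data from {input_path}. Falling back to loading it into memory.")
            return imageio.imread(input_path)

    volume = load_volume("/path/to/volume.tif")  # hypothetical path
    # numpy arrays and memmaps have no 'chunks' attribute, while zarr and
    # z5py datasets do, so a single getattr covers every branch.
    chunks = getattr(volume, "chunks", (64, 64, 64))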
scripts/extract_block.py

Lines changed: 1 addition & 3 deletions
@@ -66,9 +66,7 @@ def main(
     roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))
 
     if s3:
-        bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
-
-        s3_path, fs = s3_utils.get_s3_path(input_file, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+        s3_path, fs = s3_utils.get_s3_path(input_file, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
 
         with zarr.open(s3_path, mode="r") as f:
             raw = f[input_key][roi]

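After this cleanup, pulling a block out of a remote volume is a two-statement affair. A sketch of the simplified call site, with hypothetical center coordinates, halo, and dataset key:

    import zarr

    from flamingo_tools import s3_utils

    # Hypothetical ROI: a 100-voxel cube around a center coordinate.
    center, halo = (512, 512, 512), (50, 50, 50)
    roi = tuple(slice(c - h, c + h) for c, h in zip(center, halo))

    s3_path, fs = s3_utils.get_s3_path("example.ome.zarr")  # hypothetical object path
    with zarr.open(s3_path, mode="r") as f:
        raw = f["s0"][roi]  # "s0" is an assumed multiscale level key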
scripts/prediction/count_cells.py

Lines changed: 1 addition & 3 deletions
@@ -28,9 +28,7 @@ def main():
         raise ValueError("Either provide an output_folder containing 'segmentation.zarr' or an S3 input.")
 
     if args.s3_input is not None:
-        bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials)
-
-        s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+        s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials)
         with zarr.open(s3_path, mode="r") as f:
             dataset = f[args.input_key]

scripts/prediction/expand_seg_table.py

Lines changed: 1 addition & 2 deletions
@@ -19,8 +19,7 @@ def main(
     :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported
     """
     if s3:
-        bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
-        tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+        tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
         with fs.open(tsv_path, 'r') as f:
             tsv_table = pd.read_csv(f, sep="\t")
     else:

scripts/prediction/postprocess_seg.py

Lines changed: 2 additions & 4 deletions
@@ -37,14 +37,12 @@ def main():
     tsv_table = None
 
     if args.s3_input is not None:
-        bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials)
-
-        s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+        s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials)
         with zarr.open(s3_path, mode="r") as f:
             segmentation = f[args.input_key]
 
         if args.tsv is not None:
-            tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+            tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials)
             with fs.open(tsv_path, 'r') as f:
                 tsv_table = pd.read_csv(f, sep="\t")

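The same two-step pattern covers tabular data: get_s3_path returns the remote path together with an s3fs filesystem handle, so tables stream through fs.open instead of being downloaded first. A sketch with a hypothetical table path:

    import pandas as pd

    from flamingo_tools import s3_utils

    tsv_path, fs = s3_utils.get_s3_path("tables/segmentation/default.tsv")  # hypothetical path
    with fs.open(tsv_path, "r") as f:
        tsv_table = pd.read_csv(f, sep="\t")
    print(tsv_table.head())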