Skip to content

Commit 5a5207d

Browse files
committed
Improved style
1 parent 9d3a61f commit 5a5207d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
8484
print("Predict with CPU")
8585
gpu_ids = ["cpu"]
8686

87-
if None == mean or None == std:
87+
if mean is None or std is None:
8888
# Compute the global mean and standard deviation.
8989
n_threads = min(16, mp.cpu_count())
9090
mean, std = parallel.mean_and_std(
@@ -111,8 +111,7 @@ def postprocess(x):
111111

112112
blocking = nt.blocking([0] * ndim, shape, block_shape)
113113
n_blocks = blocking.numberOfBlocks
114-
iteration_ids = []
115-
if 1 != prediction_instances:
114+
if prediction_instances != 1:
116115
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
117116
slurm_iteration = iteration_ids[slurm_task_id]
118117
else:
@@ -264,7 +263,7 @@ def calc_mean_and_std(input_path, input_key, output_folder):
264263
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
265264
mask=image_mask
266265
)
267-
ddict = {"mean":str(mean), "std": str(std)}
266+
ddict = {"mean":mean, "std":std}
268267
with open(json_file, "w") as f:
269268
json.dump(ddict, f)
270269

0 commit comments

Comments
 (0)