Skip to content

Commit a475402

Browse files
authored
add all the IJ dimensions to the image
1 parent e5541e0 commit a475402

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -420,24 +420,32 @@ def write_im(path, im, axes, pixel_size=None):
420420
# convert the image to expects (Z)CYX axis order
421421
if im.ndim == 4:
422422
assert set(axes) == {"b", "x", "y", "c"}, f"{axes}"
423-
axes_ij = "bcyx"
423+
resolution_axes_ij = "cyxb"
424424
else:
425425
assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}"
426-
axes_ij = "bzcyx"
427-
428-
axis_permutation = tuple(axes.index(ax) for ax in axes_ij)
426+
resolution_axes_ij = "bzcyx"
427+
428+
def addMissingAxes(im_axes):
429+
needed_axes = ["b", "c", "x", "y", "z", "s"]
430+
for ax in needed_axes:
431+
if not ax in im_axes:
432+
im_axes += ax
433+
return im_axes
434+
435+
axes_ij = "bzcyxs"
436+
# Expand the image to ImageJ dimensions
437+
im = np.expand_dims(im, axis=tuple(range(len(axes),len(axes_ij))))
438+
439+
440+
axis_permutation = tuple(addMissingAxes(axes).index(ax) for ax in axes_ij)
429441
im = im.transpose(axis_permutation)
430-
# expand to TZCYXS
431-
if len(axes_ij) == 4: # add singleton t and z axis
432-
im = im[None, None]
433-
else: # add singeton z axis
434-
im = im[None]
442+
435443

436444
if pixel_size is None:
437445
resolution = None
438446
else:
439-
spatial_axes = list(set(axes_ij) - set("bc"))
440-
resolution = tuple(1.0 / pixel_size[ax] for ax in axes_ij if ax in spatial_axes)
447+
spatial_axes = list(set(resolution_axes_ij) - set("bc"))
448+
resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes)
441449
# does not work for double
442450
if np.dtype(im.dtype) == np.dtype("float64"):
443451
im = im.astype("float32")

tests/build_spec/test_build_spec.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,6 @@ def _test_build_spec(
9494
input_names=[inp.name for inp in model_spec.inputs],
9595
output_names=[out.name for out in model_spec.outputs],
9696
)
97-
print("**********************************************")
98-
print("**********************************************")
99-
print("**********************************************")
100-
print(model_spec.name)
101-
print(model_spec.root_path)
102-
print(model_spec.rdf_source)
10397
if architecture is not None:
10498
kwargs["architecture"] = architecture
10599
if model_kwargs is not None:

0 commit comments

Comments
 (0)