Skip to content

Commit 3fbf0c5

Browse files
committed
improve get_io_sample_block_metas
1 parent 18bed26 commit 3fbf0c5

File tree

1 file changed

+4
-34
lines changed

1 file changed

+4
-34
lines changed

bioimageio/core/digest_spec.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,6 @@ def get_io_sample_block_metas(
287287
t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
288288
for t in {tt for tt, _ in block_axis_sizes.inputs}
289289
}
290-
output_block_shape = {
291-
t: {
292-
aa: s
293-
for (tt, aa), s in block_axis_sizes.outputs.items()
294-
if tt == t and not isinstance(s, tuple)
295-
}
296-
for t in {tt for tt, _ in block_axis_sizes.outputs}
297-
}
298290
output_halo = {
299291
t.id: {
300292
a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
@@ -303,36 +295,14 @@ def get_io_sample_block_metas(
303295
}
304296
input_halo = get_input_halo(model, output_halo)
305297

306-
# TODO: fix output_sample_shape_data_dep
307-
# (below only valid if input_sample_shape is a valid model input,
308-
# which is not a valid assumption)
309-
output_sample_shape_data_dep = model.get_output_tensor_sizes(input_sample_shape)
310-
311-
output_sample_shape = {
312-
t: {
313-
a: -1 if isinstance(s, tuple) else s
314-
for a, s in output_sample_shape_data_dep[t].items()
315-
}
316-
for t in output_sample_shape_data_dep
317-
}
318298
n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
319299
input_sample_shape, input_block_shape, halo=input_halo
320300
)
321-
n_output_blocks, output_blocks = split_multiple_shapes_into_blocks(
322-
output_sample_shape, output_block_shape, halo=output_halo
323-
)
324-
assert n_input_blocks == n_output_blocks
301+
block_transform = get_block_transform(model)
325302
return n_input_blocks, (
326-
IO_SampleBlockMeta(ipt, out)
327-
for ipt, out in zip(
328-
sample_block_meta_generator(
329-
input_blocks, sample_shape=input_sample_shape, sample_id=None
330-
),
331-
sample_block_meta_generator(
332-
output_blocks,
333-
sample_shape=output_sample_shape,
334-
sample_id=None,
335-
),
303+
IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
304+
for ipt in sample_block_meta_generator(
305+
input_blocks, sample_shape=input_sample_shape, sample_id=None
336306
)
337307
)
338308

0 commit comments

Comments
 (0)