@@ -287,14 +287,6 @@ def get_io_sample_block_metas(
287287 t : {aa : s for (tt , aa ), s in block_axis_sizes .inputs .items () if tt == t }
288288 for t in {tt for tt , _ in block_axis_sizes .inputs }
289289 }
290- output_block_shape = {
291- t : {
292- aa : s
293- for (tt , aa ), s in block_axis_sizes .outputs .items ()
294- if tt == t and not isinstance (s , tuple )
295- }
296- for t in {tt for tt , _ in block_axis_sizes .outputs }
297- }
298290 output_halo = {
299291 t .id : {
300292 a .id : Halo (a .halo , a .halo ) for a in t .axes if isinstance (a , v0_5 .WithHalo )
@@ -303,36 +295,14 @@ def get_io_sample_block_metas(
303295 }
304296 input_halo = get_input_halo (model , output_halo )
305297
306- # TODO: fix output_sample_shape_data_dep
307- # (below only valid if input_sample_shape is a valid model input,
308- # which is not a valid assumption)
309- output_sample_shape_data_dep = model .get_output_tensor_sizes (input_sample_shape )
310-
311- output_sample_shape = {
312- t : {
313- a : - 1 if isinstance (s , tuple ) else s
314- for a , s in output_sample_shape_data_dep [t ].items ()
315- }
316- for t in output_sample_shape_data_dep
317- }
318298 n_input_blocks , input_blocks = split_multiple_shapes_into_blocks (
319299 input_sample_shape , input_block_shape , halo = input_halo
320300 )
321- n_output_blocks , output_blocks = split_multiple_shapes_into_blocks (
322- output_sample_shape , output_block_shape , halo = output_halo
323- )
324- assert n_input_blocks == n_output_blocks
301+ block_transform = get_block_transform (model )
325302 return n_input_blocks , (
326- IO_SampleBlockMeta (ipt , out )
327- for ipt , out in zip (
328- sample_block_meta_generator (
329- input_blocks , sample_shape = input_sample_shape , sample_id = None
330- ),
331- sample_block_meta_generator (
332- output_blocks ,
333- sample_shape = output_sample_shape ,
334- sample_id = None ,
335- ),
303+ IO_SampleBlockMeta (ipt , ipt .get_transformed (block_transform ))
304+ for ipt in sample_block_meta_generator (
305+ input_blocks , sample_shape = input_sample_shape , sample_id = None
336306 )
337307 )
338308
0 commit comments