@@ -413,46 +413,46 @@ def get_size(fname, axes):
413413def _write_sample_data (input_paths , output_paths , input_axes , output_axes , pixel_sizes , export_folder : Path ):
414414 def write_im (path , im , axes , pixel_size = None ):
415415 assert tifffile is not None , "need tifffile for writing deepimagej config"
416- assert len (axes ) == im .ndim
417- assert im .ndim in (3 , 4 )
416+ assert len (axes ) == im .ndim , f" { len ( axes ), { im . ndim } } "
417+ assert im .ndim in (4 , 5 )
418418
419419 # convert the image to expects (Z)CYX axis order
420420 if im .ndim == 3 :
421- assert set (axes ) == {"x" , "y" , "c" }
422- axes_ij = "cyx "
421+ assert set (axes ) == {"b" , " x" , "y" , "c" }
422+ axes_ij = "cyxb "
423423 else :
424- assert set (axes ) == {"x" , "y" , "z" , "c" }
425- axes_ij = "zcyx "
424+ assert set (axes ) == {"b" , " x" , "y" , "z" , "c" }
425+ axes_ij = "zcyxb "
426426
427- axis_permutation = tuple (axes_ij .index (ax ) for ax in axes )
427+ axis_permutation = tuple (axes .index (ax ) for ax in axes_ij )
428428 im = im .transpose (axis_permutation )
429429 # expand to TZCYXS
430- if len (axes_ij ) == 2 : # add singleton z axis
431- im = im [None , None , ..., None ]
432- else :
433- im = im [None , ..., None ]
430+ if len (axes_ij ) == 2 : # add singleton t and z axis
431+ im = im [None , None ]
432+ else : # add singeton z axis
433+ im = im [None ]
434434
435435 if pixel_size is None :
436436 resolution = None
437437 else :
438- spatial_axes = list (set (axes_ij ) - set ([ "c" ] ))
439- resolution = tuple (1.0 / pixel_size [ax ] for ax in spatial_axes )
438+ spatial_axes = list (set (axes_ij ) - set ("bc" ))
439+ resolution = tuple (1.0 / pixel_size [ax ] for ax in axes_ij if ax in spatial_axes )
440440 # does not work for double
441441 if np .dtype (im .dtype ) == np .dtype ("float64" ):
442442 im = im .astype ("float32" )
443443 tifffile .imsave (path , im , imagej = True , resolution = resolution )
444444
445445 sample_in_paths = []
446446 for i , (in_path , axes ) in enumerate (zip (input_paths , input_axes )):
447- inp = np .load (export_folder / in_path )[ 0 ]
447+ inp = np .load (export_folder / in_path )
448448 sample_in_path = export_folder / f"sample_input_{ i } .tif"
449449 pixel_size = None if pixel_sizes is None else pixel_sizes [i ]
450450 write_im (sample_in_path , inp , axes , pixel_size )
451451 sample_in_paths .append (sample_in_path )
452452
453453 sample_out_paths = []
454454 for i , (out_path , axes ) in enumerate (zip (output_paths , output_axes )):
455- outp = np .load (export_folder / out_path )[ 0 ]
455+ outp = np .load (export_folder / out_path )
456456 sample_out_path = export_folder / f"sample_output_{ i } .tif"
457457 write_im (sample_out_path , outp , axes )
458458 sample_out_paths .append (sample_out_path )
0 commit comments