@@ -338,8 +338,10 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
338338 return {"spec" : "ij.IJ::runMacroFile" , "kwargs" : macro }
339339
340340
341- def _get_deepimagej_config (export_folder , sample_inputs , sample_outputs , pixel_sizes , preprocessing , postprocessing ):
342- assert len (sample_inputs ) == len (sample_outputs ) == 1 , "deepimagej config only valid for single input/output"
341+ def _get_deepimagej_config (
342+ export_folder , test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , preprocessing , postprocessing
343+ ):
344+ assert len (test_inputs ) == len (test_outputs ) == 1 , "deepimagej config only valid for single input/output"
343345
344346 if any (preproc is not None for preproc in preprocessing ):
345347 assert len (preprocessing ) == 1
@@ -363,13 +365,21 @@ def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_s
363365 else :
364366 postprocess_ij = [{"spec" : None }]
365367
366- def get_size (path ):
367- assert tifffile is not None , "need tifffile for writing deepimagej config"
368- with tifffile .TiffFile (export_folder / path ) as f :
369- shape = f .asarray ().shape
370- # add singleton z axis if we have 2d data
368+ def get_size (fname , axes ):
369+ shape = np .load (export_folder / fname ).shape
370+ assert len (shape ) == len (axes )
371+ shape = [sh for sh , ax in zip (shape , axes ) if ax != "b" ]
372+ axes = [ax for ax in axes if ax != "b" ]
373+ # the shape for deepij is always given as xyzc
374+ if len (shape ) == 3 :
375+ axes_ij = "xyc"
376+ else :
377+ axes_ij = "xyzc"
378+ assert set (axes ) == set (axes_ij )
379+ axis_permutation = [axes_ij .index (ax ) for ax in axes ]
380+ shape = [shape [permut ] for permut in axis_permutation ]
371381 if len (shape ) == 3 :
372- shape = shape [:2 ] + ( 1 ,) + shape [- 1 :]
382+ shape = shape [:2 ] + [ 1 ] + shape [- 1 :]
373383 assert len (shape ) == 4
374384 return " x " .join (map (str , shape ))
375385
@@ -378,10 +388,13 @@ def get_size(path):
378388
379389 test_info = {
380390 "inputs" : [
381- {"name" : in_path , "size" : get_size (in_path ), "pixel_size" : pix_size }
382- for in_path , pix_size in zip (sample_inputs , pixel_sizes_ )
391+ {"name" : in_path , "size" : get_size (in_path , axes ), "pixel_size" : pix_size }
392+ for in_path , axes , pix_size in zip (test_inputs , input_axes , pixel_sizes_ )
393+ ],
394+ "outputs" : [
395+ {"name" : out_path , "type" : "image" , "size" : get_size (out_path , axes )}
396+ for out_path , axes in zip (test_outputs , output_axes )
383397 ],
384- "outputs" : [{"name" : out_path , "type" : "image" , "size" : get_size (out_path )} for out_path in sample_outputs ],
385398 "memory_peak" : None ,
386399 "runtime" : None ,
387400 }
@@ -397,36 +410,49 @@ def get_size(path):
397410 return {"deepimagej" : config }, [Path (a ) for a in attachments ]
398411
399412
400- def _write_sample_data (input_paths , output_paths , input_axes , output_axes , export_folder : Path ):
401- def write_im (path , im , axes ):
413+ def _write_sample_data (input_paths , output_paths , input_axes , output_axes , pixel_sizes , export_folder : Path ):
414+ def write_im (path , im , axes , pixel_size = None ):
402415 assert tifffile is not None , "need tifffile for writing deepimagej config"
403- assert len (axes ) == im .ndim
404- assert im .ndim in (3 , 4 )
416+ assert len (axes ) == im .ndim , f" { len ( axes ), { im . ndim } } "
417+ assert im .ndim in (4 , 5 ), f" { im . ndim } "
405418
406- # deepimagej expects xyzc axis order
407- if im .ndim == 3 :
408- assert set (axes ) == {"x" , "y" , "c" }
409- axes_ij = "xyc "
419+ # convert the image to expects (Z)CYX axis order
420+ if im .ndim == 4 :
421+ assert set (axes ) == {"b" , " x" , "y" , "c" }, f" { axes } "
422+ axes_ij = "cyxb "
410423 else :
411- assert set (axes ) == {"x" , "y" , "z" , "c" }
412- axes_ij = "xyzc "
424+ assert set (axes ) == {"b" , " x" , "y" , "z" , "c" }, f" { axes } "
425+ axes_ij = "zcyxb "
413426
414- axis_permutation = tuple (axes_ij .index (ax ) for ax in axes )
427+ axis_permutation = tuple (axes .index (ax ) for ax in axes_ij )
415428 im = im .transpose (axis_permutation )
416-
417- with tifffile .TiffWriter (path ) as f :
418- f .write (im )
429+ # expand to TZCYXS
430+ if len (axes_ij ) == 4 : # add singleton t and z axis
431+ im = im [None , None ]
432+ else : # add singeton z axis
433+ im = im [None ]
434+
435+ if pixel_size is None :
436+ resolution = None
437+ else :
438+ spatial_axes = list (set (axes_ij ) - set ("bc" ))
439+ resolution = tuple (1.0 / pixel_size [ax ] for ax in axes_ij if ax in spatial_axes )
440+ # does not work for double
441+ if np .dtype (im .dtype ) == np .dtype ("float64" ):
442+ im = im .astype ("float32" )
443+ tifffile .imsave (path , im , imagej = True , resolution = resolution )
419444
420445 sample_in_paths = []
421446 for i , (in_path , axes ) in enumerate (zip (input_paths , input_axes )):
422- inp = np .load (export_folder / in_path )[ 0 ]
447+ inp = np .load (export_folder / in_path )
423448 sample_in_path = export_folder / f"sample_input_{ i } .tif"
424- write_im (sample_in_path , inp , axes )
449+ pixel_size = None if pixel_sizes is None else pixel_sizes [i ]
450+ write_im (sample_in_path , inp , axes , pixel_size )
425451 sample_in_paths .append (sample_in_path )
426452
427453 sample_out_paths = []
428454 for i , (out_path , axes ) in enumerate (zip (output_paths , output_axes )):
429- outp = np .load (export_folder / out_path )[ 0 ]
455+ outp = np .load (export_folder / out_path )
430456 sample_out_path = export_folder / f"sample_output_{ i } .tif"
431457 write_im (sample_out_path , outp , axes )
432458 sample_out_paths .append (sample_out_path )
@@ -797,17 +823,15 @@ def build_model(
797823 # add the deepimagej config if specified
798824 if add_deepimagej_config :
799825 if sample_inputs is None :
800- input_axes_ij = [inp .axes [1 :] for inp in inputs ]
801- output_axes_ij = [out .axes [1 :] for out in outputs ]
802826 sample_inputs , sample_outputs = _write_sample_data (
803- test_inputs , test_outputs , input_axes_ij , output_axes_ij , root
827+ test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , root
804828 )
805829 # deepimagej expect tifs as sample data
806830 assert all (os .path .splitext (path )[1 ] in (".tif" , ".tiff" ) for path in sample_inputs )
807831 assert all (os .path .splitext (path )[1 ] in (".tif" , ".tiff" ) for path in sample_outputs )
808832
809833 ij_config , ij_attachments = _get_deepimagej_config (
810- root , sample_inputs , sample_outputs , pixel_sizes , preprocessing , postprocessing
834+ root , test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , preprocessing , postprocessing
811835 )
812836
813837 if config is None :
0 commit comments