@@ -47,11 +47,6 @@ def scale_mean_variance(tensor: xr.DataArray, *, reference_tensor, mode="per_sam
4747 raise NotImplementedError
4848
4949
50- # NOTE there is a subtle difference between pre-and-postprocessing:
51- # pre-processing always returns float32, because the post-processing output is consumed
52- # by the model. Post-processing, however, should return the dtype that is specified in the model spec
53- # TODO I think the easiest way to implement this is to add dtype is an option to 'make_postprocessing'
54- # and then apply 'ensure_dtype' to the result of the postprocessing chain
5550KNOWN_POSTPROCESSING = {
5651 "binarize" : ops .binarize ,
5752 "clip" : ops .clip ,
@@ -62,7 +57,7 @@ def scale_mean_variance(tensor: xr.DataArray, *, reference_tensor, mode="per_sam
6257}
6358
6459
65- def make_postprocessing (spec : List [Postprocessing ]) -> Transform :
60+ def make_postprocessing (spec : List [Postprocessing ], dtype : str ) -> Transform :
6661 """
6762 :param preprocessing: bioimage-io spec node
6863 """
@@ -78,4 +73,9 @@ def make_postprocessing(spec: List[Postprocessing]) -> Transform:
7873
7974 functions .append ((fn , kwargs ))
8075
76+ # There is a difference between pre-and-postprocessing:
77+ # Tre-processing always returns float32, because its output is consumed y the model.
78+ # Post-processing, however, should return the dtype that is specified in the model spec.
79+ functions .append ((ops .ensure_dtype , {"dtype" : dtype }))
80+
8181 return ops .chain (* functions )
0 commit comments