@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
236236 `np.ndarray` or `torch.Tensor`:
237237 The denormalized image array.
238238 """
239- return (images / 2 + 0.5 ).clamp (0 , 1 )
239+ return (images * 0.5 + 0.5 ).clamp (0 , 1 )
240240
241241 @staticmethod
242242 def convert_to_rgb (image : PIL .Image .Image ) -> PIL .Image .Image :
@@ -537,6 +537,26 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
537537
538538 return image
539539
540+ def _denormalize_conditionally (
541+ self , images : torch .Tensor , do_denormalize : Optional [List [bool ]] = None
542+ ) -> torch .Tensor :
543+ r"""
544+ Denormalize a batch of images based on a condition list.
545+
546+ Args:
547+ images (`torch.Tensor`):
548+ The input image tensor.
549+ do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
550+ A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
551+ value of `do_normalize` in the `VaeImageProcessor` config.
552+ """
553+ if do_denormalize is None :
554+ return self .denormalize (images ) if self .config .do_normalize else images
555+
556+ return torch .stack (
557+ [self .denormalize (images [i ]) if do_denormalize [i ] else images [i ] for i in range (images .shape [0 ])]
558+ )
559+
540560 def get_default_height_width (
541561 self ,
542562 image : Union [PIL .Image .Image , np .ndarray , torch .Tensor ],
@@ -752,12 +772,7 @@ def postprocess(
752772 if output_type == "latent" :
753773 return image
754774
755- if do_denormalize is None :
756- do_denormalize = [self .config .do_normalize ] * image .shape [0 ]
757-
758- image = torch .stack (
759- [self .denormalize (image [i ]) if do_denormalize [i ] else image [i ] for i in range (image .shape [0 ])]
760- )
775+ image = self ._denormalize_conditionally (image , do_denormalize )
761776
762777 if output_type == "pt" :
763778 return image
@@ -966,12 +981,7 @@ def postprocess(
966981 deprecate ("Unsupported output_type" , "1.0.0" , deprecation_message , standard_warn = False )
967982 output_type = "np"
968983
969- if do_denormalize is None :
970- do_denormalize = [self .config .do_normalize ] * image .shape [0 ]
971-
972- image = torch .stack (
973- [self .denormalize (image [i ]) if do_denormalize [i ] else image [i ] for i in range (image .shape [0 ])]
974- )
984+ image = self ._denormalize_conditionally (image , do_denormalize )
975985
976986 image = self .pt_to_numpy (image )
977987
0 commit comments