@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
236236 `np.ndarray` or `torch.Tensor`:
237237 The denormalized image array.
238238 """
239- return (images / 2 + 0.5 ).clamp (0 , 1 )
239+ return (images * 0.5 + 0.5 ).clamp (0 , 1 )
240240
241241 @staticmethod
242242 def convert_to_rgb (image : PIL .Image .Image ) -> PIL .Image .Image :
@@ -537,6 +537,27 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
537537
538538 return image
539539
540+ def _denormalize_conditionally (self , images : torch .Tensor , do_denormalize : Optional [List [bool ]] = None ) -> torch .Tensor :
541+ r"""
542+ Denormalize a batch of images based on a condition list.
543+
544+ Args:
545+ images (`torch.Tensor`):
546+ The input image tensor.
547+ do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
548+ A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
549+ value of `do_normalize` in the `VaeImageProcessor` config.
550+ """
551+ if do_denormalize is None :
552+ return self .denormalize (images ) if self .config .do_normalize else images
553+
554+ # De-normalizing a batch and selectively torch.stack'ing the results turns out to be
555+ # significantly faster than performing a lot of smaller denormalizations
556+ denormalized = self .denormalize (images )
557+ return torch .stack ([
558+ denormalized [i ] if do_denormalize [i ] else images [i ] for i in range (images .shape [0 ])
559+ ])
560+
540561 def get_default_height_width (
541562 self ,
542563 image : Union [PIL .Image .Image , np .ndarray , torch .Tensor ],
@@ -752,12 +773,7 @@ def postprocess(
752773 if output_type == "latent" :
753774 return image
754775
755- if do_denormalize is None :
756- do_denormalize = [self .config .do_normalize ] * image .shape [0 ]
757-
758- image = torch .stack (
759- [self .denormalize (image [i ]) if do_denormalize [i ] else image [i ] for i in range (image .shape [0 ])]
760- )
776+ image = self ._denormalize_conditionally (image , do_denormalize )
761777
762778 if output_type == "pt" :
763779 return image
@@ -966,12 +982,7 @@ def postprocess(
966982 deprecate ("Unsupported output_type" , "1.0.0" , deprecation_message , standard_warn = False )
967983 output_type = "np"
968984
969- if do_denormalize is None :
970- do_denormalize = [self .config .do_normalize ] * image .shape [0 ]
971-
972- image = torch .stack (
973- [self .denormalize (image [i ]) if do_denormalize [i ] else image [i ] for i in range (image .shape [0 ])]
974- )
985+ image = self ._denormalize_conditionally (image , do_denormalize )
975986
976987 image = self .pt_to_numpy (image )
977988
0 commit comments