@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
236236            `np.ndarray` or `torch.Tensor`: 
237237                The denormalized image array. 
238238        """ 
239-         return  (images  /   2  +  0.5 ).clamp (0 , 1 )
239+         return  (images  *   0.5  +  0.5 ).clamp (0 , 1 )
240240
241241    @staticmethod  
242242    def  convert_to_rgb (image : PIL .Image .Image ) ->  PIL .Image .Image :
@@ -537,6 +537,26 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
537537
538538        return  image 
539539
540+     def  _denormalize_conditionally (
541+         self , images : torch .Tensor , do_denormalize : Optional [List [bool ]] =  None 
542+     ) ->  torch .Tensor :
543+         r""" 
544+         Denormalize a batch of images based on a condition list. 
545+ 
546+         Args: 
547+             images (`torch.Tensor`): 
548+                 The input image tensor. 
549+             do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): 
550+                 A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the 
551+                 value of `do_normalize` in the `VaeImageProcessor` config. 
552+         """ 
553+         if  do_denormalize  is  None :
554+             return  self .denormalize (images ) if  self .config .do_normalize  else  images 
555+ 
556+         return  torch .stack (
557+             [self .denormalize (images [i ]) if  do_denormalize [i ] else  images [i ] for  i  in  range (images .shape [0 ])]
558+         )
559+ 
540560    def  get_default_height_width (
541561        self ,
542562        image : Union [PIL .Image .Image , np .ndarray , torch .Tensor ],
@@ -752,12 +772,7 @@ def postprocess(
752772        if  output_type  ==  "latent" :
753773            return  image 
754774
755-         if  do_denormalize  is  None :
756-             do_denormalize  =  [self .config .do_normalize ] *  image .shape [0 ]
757- 
758-         image  =  torch .stack (
759-             [self .denormalize (image [i ]) if  do_denormalize [i ] else  image [i ] for  i  in  range (image .shape [0 ])]
760-         )
775+         image  =  self ._denormalize_conditionally (image , do_denormalize )
761776
762777        if  output_type  ==  "pt" :
763778            return  image 
@@ -966,12 +981,7 @@ def postprocess(
966981            deprecate ("Unsupported output_type" , "1.0.0" , deprecation_message , standard_warn = False )
967982            output_type  =  "np" 
968983
969-         if  do_denormalize  is  None :
970-             do_denormalize  =  [self .config .do_normalize ] *  image .shape [0 ]
971- 
972-         image  =  torch .stack (
973-             [self .denormalize (image [i ]) if  do_denormalize [i ] else  image [i ] for  i  in  range (image .shape [0 ])]
974-         )
984+         image  =  self ._denormalize_conditionally (image , do_denormalize )
975985
976986        image  =  self .pt_to_numpy (image )
977987
0 commit comments