Skip to content

Commit f94fd19

Browse files
committed
Improve post-processing performance
* Use multiplication instead of division * Avoid splitting and re-stacking tensors to reduce memory bandwidth and CPU-GPU syncs
1 parent 6131a93 commit f94fd19

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

src/diffusers/image_processor.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
236236
`np.ndarray` or `torch.Tensor`:
237237
The denormalized image array.
238238
"""
239-
return (images / 2 + 0.5).clamp(0, 1)
239+
return (images * 0.5 + 0.5).clamp(0, 1)
240240

241241
@staticmethod
242242
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -537,6 +537,27 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
537537

538538
return image
539539

540+
def _denormalize_conditionally(self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None) -> torch.Tensor:
541+
r"""
542+
Denormalize a batch of images based on a condition list.
543+
544+
Args:
545+
images (`torch.Tensor`):
546+
The input image tensor.
547+
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
548+
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
549+
value of `do_normalize` in the `VaeImageProcessor` config.
550+
"""
551+
if do_denormalize is None:
552+
return self.denormalize(images) if self.config.do_normalize else images
553+
554+
# De-normalizing a batch and selectively torch.stack'ing the results turns out to be
555+
# significantly faster than performing a lot of smaller denormalizations
556+
denormalized = self.denormalize(images)
557+
return torch.stack([
558+
denormalized[i] if do_denormalize[i] else images[i] for i in range(images.shape[0])
559+
])
560+
540561
def get_default_height_width(
541562
self,
542563
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
@@ -752,12 +773,7 @@ def postprocess(
752773
if output_type == "latent":
753774
return image
754775

755-
if do_denormalize is None:
756-
do_denormalize = [self.config.do_normalize] * image.shape[0]
757-
758-
image = torch.stack(
759-
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
760-
)
776+
image = self._denormalize_conditionally(image, do_denormalize)
761777

762778
if output_type == "pt":
763779
return image
@@ -966,12 +982,7 @@ def postprocess(
966982
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
967983
output_type = "np"
968984

969-
if do_denormalize is None:
970-
do_denormalize = [self.config.do_normalize] * image.shape[0]
971-
972-
image = torch.stack(
973-
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
974-
)
985+
image = self._denormalize_conditionally(image, do_denormalize)
975986

976987
image = self.pt_to_numpy(image)
977988

0 commit comments

Comments
 (0)