# SPDX-License-Identifier: Apache-2.0
"""
Minimal image processing utilities for FastVideo.
This module provides lightweight image preprocessing without external dependencies beyond PyTorch/NumPy/PIL.
"""

from typing import Optional, Union

import numpy as np
import PIL.Image
import torch


class ImageProcessor:
    """
    Minimal image processor for video frame preprocessing.

    This is a lightweight alternative to diffusers.VideoProcessor that handles:
    - PIL image to tensor conversion
    - Resizing to specified dimensions
    - Normalization to the [-1, 1] range

    Args:
        vae_scale_factor: The VAE scale factor; target dimensions are rounded down to
            multiples of this value.
    """

    def __init__(self, vae_scale_factor: int = 8) -> None:
        self.vae_scale_factor = vae_scale_factor

    def preprocess(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Preprocess an image to a normalized torch tensor.

        Args:
            image: Input image (PIL Image, NumPy array, or torch tensor)
            height: Target height. If None, uses the image's original height.
            width: Target width. If None, uses the image's original width.

        Returns:
            torch.Tensor: Normalized tensor of shape (1, 3, height, width), or
                (1, 1, height, width) for grayscale, with values in the range [-1, 1].
        """
        # Handle different input types
        if isinstance(image, PIL.Image.Image):
            return self._preprocess_pil(image, height, width)
        elif isinstance(image, np.ndarray):
            return self._preprocess_numpy(image, height, width)
        elif isinstance(image, torch.Tensor):
            return self._preprocess_tensor(image, height, width)
        else:
            raise ValueError(
                f"Unsupported image type: {type(image)}. "
                "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor"
            )

    def _preprocess_pil(
        self,
        image: PIL.Image.Image,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """Preprocess a PIL image."""
        if height is None:
            height = image.height
        if width is None:
            width = image.width

        # Round target dimensions down to multiples of the VAE scale factor
        height = height - (height % self.vae_scale_factor)
        width = width - (width % self.vae_scale_factor)

        image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)

        # Convert to float32 in [0, 1]
        image_np = np.array(image, dtype=np.float32) / 255.0

        if image_np.ndim == 2:  # Grayscale
            image_np = np.expand_dims(image_np, axis=-1)

        return self._normalize_to_tensor(image_np)

    def _preprocess_numpy(
        self,
        image: np.ndarray,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """Preprocess a numpy array."""
        # Determine target dimensions if not provided
        if image.ndim == 3:
            img_height, img_width = image.shape[:2]
        elif image.ndim == 2:
            img_height, img_width = image.shape
        else:
            raise ValueError(f"Expected 2D or 3D array, got {image.ndim}D")

        if height is None:
            height = img_height
        if width is None:
            width = img_width

        height = height - (height % self.vae_scale_factor)
        width = width - (width % self.vae_scale_factor)

        if image.dtype == np.uint8:
            pil_image = PIL.Image.fromarray(image)
        else:
            # Assume normalized [0, 1] or similar
            if image.max() <= 1.0:
                image_uint8 = (image * 255).astype(np.uint8)
            else:
                image_uint8 = image.astype(np.uint8)
            pil_image = PIL.Image.fromarray(image_uint8)

        pil_image = pil_image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
        image_np = np.array(pil_image, dtype=np.float32) / 255.0

        # Ensure 3D shape
        if image_np.ndim == 2:
            image_np = np.expand_dims(image_np, axis=-1)

        return self._normalize_to_tensor(image_np)

    def _preprocess_tensor(
        self,
        image: torch.Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """Preprocess a torch tensor."""
        # Determine target dimensions
        if image.ndim == 3:  # (H, W, C) or (C, H, W)
            if image.shape[0] in (1, 3, 4):  # Likely (C, H, W)
                img_height, img_width = image.shape[1], image.shape[2]
            else:  # Likely (H, W, C)
                img_height, img_width = image.shape[0], image.shape[1]
        elif image.ndim == 2:  # (H, W)
            img_height, img_width = image.shape
        else:
            raise ValueError(f"Expected 2D or 3D tensor, got {image.ndim}D")

        if height is None:
            height = img_height
        if width is None:
            width = img_width

        height = height - (height % self.vae_scale_factor)
        width = width - (width % self.vae_scale_factor)

        if image.ndim == 2:
            image = image.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
        elif image.ndim == 3:
            if image.shape[0] in (1, 3, 4):  # (C, H, W)
                image = image.unsqueeze(0)  # (1, C, H, W)
            else:  # (H, W, C) - need to rearrange
                image = image.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)

        # Bilinear interpolation requires a floating-point tensor (integer inputs such as uint8 would fail)
        if not torch.is_floating_point(image):
            image = image.float()

        image = torch.nn.functional.interpolate(
            image, size=(height, width), mode="bilinear", align_corners=False
        )

        if image.max() > 1.0:  # Assume [0, 255] range
            image = image / 255.0

        image = 2.0 * image - 1.0

        return image

    def _normalize_to_tensor(self, image_np: np.ndarray) -> torch.Tensor:
        """
        Convert normalized numpy array [0, 1] to torch tensor [-1, 1].

        Args:
            image_np: NumPy array with shape (H, W) or (H, W, C) with values in [0, 1]

        Returns:
            torch.Tensor: Shape (1, C, H, W) or (1, 1, H, W) with values in [-1, 1]
        """
        # Convert to tensor
        if image_np.ndim == 2:  # (H, W) - grayscale
            tensor = torch.from_numpy(image_np).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
        elif image_np.ndim == 3:  # (H, W, C)
            tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
        else:
            raise ValueError(f"Expected 2D or 3D array, got {image_np.ndim}D")

        # Normalize to [-1, 1]
        tensor = 2.0 * tensor - 1.0

        return tensor
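

# Minimal usage sketch (illustrative only, not part of the module's API above): shows how
# ImageProcessor might be called on a synthetic PIL image and a NumPy frame. The image sizes
# and the vae_scale_factor value are arbitrary example choices, not values required by FastVideo.
if __name__ == "__main__":
    processor = ImageProcessor(vae_scale_factor=8)

    # PIL input: a solid-color RGB image; expected output shape is (1, 3, 480, 640) in [-1, 1].
    pil_image = PIL.Image.new("RGB", (640, 480), color=(128, 64, 32))
    pil_tensor = processor.preprocess(pil_image)
    print(pil_tensor.shape, float(pil_tensor.min()), float(pil_tensor.max()))

    # NumPy input: a random uint8 frame with an explicit target size; 720 and 1280 are
    # already multiples of the scale factor, so no rounding occurs.
    np_frame = np.random.randint(0, 256, size=(720, 1280, 3), dtype=np.uint8)
    np_tensor = processor.preprocess(np_frame, height=720, width=1280)
    print(np_tensor.shape)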