44This module provides lightweight image preprocessing without external dependencies beyond PyTorch/NumPy/PIL.
55"""
66
7- from typing import Optional , Union
8-
97import numpy as np
108import PIL .Image
119import torch
@@ -29,9 +27,9 @@ def __init__(self, vae_scale_factor: int = 8) -> None:
2927
3028 def preprocess (
3129 self ,
32- image : Union [ PIL .Image .Image , np .ndarray , torch .Tensor ] ,
33- height : Optional [ int ] = None ,
34- width : Optional [ int ] = None ,
30+ image : PIL .Image .Image | np .ndarray | torch .Tensor ,
31+ height : int | None = None ,
32+ width : int | None = None ,
3533 ) -> torch .Tensor :
3634 """
3735 Preprocess an image to a normalized torch tensor.
@@ -55,14 +53,13 @@ def preprocess(
5553 else :
5654 raise ValueError (
5755 f"Unsupported image type: { type (image )} . "
58- "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor"
59- )
56+ "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor" )
6057
6158 def _preprocess_pil (
6259 self ,
6360 image : PIL .Image .Image ,
64- height : Optional [ int ] = None ,
65- width : Optional [ int ] = None ,
61+ height : int | None = None ,
62+ width : int | None = None ,
6663 ) -> torch .Tensor :
6764 """Preprocess a PIL image."""
6865 if height is None :
@@ -73,7 +70,8 @@ def _preprocess_pil(
7370 height = height - (height % self .vae_scale_factor )
7471 width = width - (width % self .vae_scale_factor )
7572
76- image = image .resize ((width , height ), resample = PIL .Image .Resampling .LANCZOS )
73+ image = image .resize ((width , height ),
74+ resample = PIL .Image .Resampling .LANCZOS )
7775
7876 image_np = np .array (image , dtype = np .float32 ) / 255.0
7977
@@ -85,8 +83,8 @@ def _preprocess_pil(
8583 def _preprocess_numpy (
8684 self ,
8785 image : np .ndarray ,
88- height : Optional [ int ] = None ,
89- width : Optional [ int ] = None ,
86+ height : int | None = None ,
87+ width : int | None = None ,
9088 ) -> torch .Tensor :
9189 """Preprocess a numpy array."""
9290 # Determine target dimensions if not provided
@@ -115,7 +113,8 @@ def _preprocess_numpy(
115113 image_uint8 = image .astype (np .uint8 )
116114 pil_image = PIL .Image .fromarray (image_uint8 )
117115
118- pil_image = pil_image .resize ((width , height ), resample = PIL .Image .Resampling .LANCZOS )
116+ pil_image = pil_image .resize ((width , height ),
117+ resample = PIL .Image .Resampling .LANCZOS )
119118 image_np = np .array (pil_image , dtype = np .float32 ) / 255.0
120119
121120 # Ensure 3D shape
@@ -127,8 +126,8 @@ def _preprocess_numpy(
127126 def _preprocess_tensor (
128127 self ,
129128 image : torch .Tensor ,
130- height : Optional [ int ] = None ,
131- width : Optional [ int ] = None ,
129+ height : int | None = None ,
130+ width : int | None = None ,
132131 ) -> torch .Tensor :
133132 """Preprocess a torch tensor."""
134133 # Determine target dimensions
@@ -158,9 +157,10 @@ def _preprocess_tensor(
158157 else : # (H, W, C) - need to rearrange
159158 image = image .permute (2 , 0 , 1 ).unsqueeze (0 ) # (1, C, H, W)
160159
161- image = torch .nn .functional .interpolate (
162- image , size = (height , width ), mode = "bilinear" , align_corners = False
163- )
160+ image = torch .nn .functional .interpolate (image ,
161+ size = (height , width ),
162+ mode = "bilinear" ,
163+ align_corners = False )
164164
165165 if image .max () > 1.0 : # Assume [0, 255] range
166166 image = image / 255.0
@@ -181,9 +181,11 @@ def _normalize_to_tensor(self, image_np: np.ndarray) -> torch.Tensor:
181181 """
182182 # Convert to tensor
183183 if image_np .ndim == 2 : # (H, W) - grayscale
184- tensor = torch .from_numpy (image_np ).unsqueeze (0 ).unsqueeze (0 ) # (1, 1, H, W)
184+ tensor = torch .from_numpy (image_np ).unsqueeze (0 ).unsqueeze (
185+ 0 ) # (1, 1, H, W)
185186 elif image_np .ndim == 3 : # (H, W, C)
186- tensor = torch .from_numpy (image_np ).permute (2 , 0 , 1 ).unsqueeze (0 ) # (1, C, H, W)
187+ tensor = torch .from_numpy (image_np ).permute (2 , 0 , 1 ).unsqueeze (
188+ 0 ) # (1, C, H, W)
187189 else :
188190 raise ValueError (f"Expected 2D or 3D array, got { image_np .ndim } D" )
189191
0 commit comments