diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index fb2cc005d7..b339a8cd52 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -60,12 +60,21 @@ def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]: return batches - def sample_transforms(self, x: np.ndarray) -> torch.Tensor: - if x.ndim != 3: - raise AssertionError("expected list of 3D Tensors") - if x.dtype not in (np.uint8, np.float32, np.float16): - raise TypeError("unsupported data type for numpy.ndarray") - tensor = torch.from_numpy(x.copy()).permute(2, 0, 1) + def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor: + if isinstance(x, np.ndarray): + if x.ndim != 3: + raise AssertionError("expected list of 3D Tensors") + if x.dtype not in (np.uint8, np.float32, np.float16): + raise TypeError("unsupported data type for numpy.ndarray") + x = torch.from_numpy(x.copy()) + elif isinstance(x, torch.Tensor): + if x.ndim != 3: + raise AssertionError("expected 3D Tensor") + else: + raise TypeError(f"invalid input type: {type(x)}") + + tensor = x.permute(2, 0, 1) + # Resizing tensor = self.resize(tensor) # Data type @@ -76,22 +85,28 @@ def sample_transforms(self, x: np.ndarray) -> torch.Tensor: return tensor - def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]: + def __call__(self, x: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor]) -> list[torch.Tensor]: """Prepare document data for model forwarding Args: - x: list of images (np.array) or a single image (np.array) of shape (H, W, C) + x: list of images (np.array | torch.Tensor) or a single image (np.array | torch.Tensor) of shape (H, W, C) Returns: list of page batches (*, C, H, W) ready for model inference """ # Input type check - if isinstance(x, np.ndarray): + if isinstance(x, np.ndarray) or isinstance(x, torch.Tensor): + if x.ndim != 4: raise AssertionError("expected 4D Tensor") - if x.dtype not in (np.uint8, np.float32, np.float16): + if x.dtype not in (np.uint8, np.float32, np.float16) and isinstance(x, np.ndarray): raise TypeError("unsupported data type for numpy.ndarray") - tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2) + if x.dtype not in (torch.uint8, torch.float32, torch.float16) and isinstance(x, torch.Tensor): + raise TypeError("unsupported data type for torch.Tensor") + if isinstance(x, np.ndarray): + x = torch.from_numpy(x.copy()) + + tensor = x.permute(0, 3, 1, 2) # Resizing if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]: @@ -105,7 +120,7 @@ def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]: tensor = tensor.to(dtype=torch.float32) batches = [tensor] - elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x): + elif isinstance(x, list) and all(isinstance(sample, np.ndarray) or isinstance(sample, torch.Tensor) for sample in x): # Sample transform (to tensor, resize) samples = list(multithread_exec(self.sample_transforms, x)) # Batching