Skip to content

Commit 48872fd

Browse files
authored
Add Image Processor Fast RT-DETR (#34354)
* add fast image processor rtdetr * add gpu/cpu test and fix docstring * remove prints * add to doc * nit docstring * avoid iterating over images/annotations several times * change torch typing * Add image processor fast documentation
1 parent 9f06fb0 commit 48872fd

File tree

12 files changed

+1259
-325
lines changed

12 files changed

+1259
-325
lines changed

docs/source/en/main_classes/image_processor.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,49 @@ rendered properly in your Markdown viewer.
1818

1919
An image processor is in charge of preparing input features for vision models and post processing their outputs. This includes transformations such as resizing, normalization, and conversion to PyTorch, TensorFlow, Flax and Numpy tensors. It may also include model specific post-processing such as converting logits to segmentation masks.
2020

21+
Fast image processors are available for a few models and more will be added in the future. They are based on the [torchvision](https://pytorch.org/vision/stable/index.html) library and provide a significant speed-up, especially when processing on GPU.
22+
They have the same API as the base image processors and can be used as drop-in replacements.
23+
To use a fast image processor, you need to install the `torchvision` library, and set the `use_fast` argument to `True` when instantiating the image processor:
24+
25+
```python
26+
from transformers import AutoImageProcessor
27+
28+
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
29+
```
30+
31+
When using a fast image processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise.
32+
33+
```python
34+
from torchvision.io import read_image
35+
from transformers import DetrImageProcessorFast
36+
37+
images = read_image("image.jpg")
38+
processor = DetrImageProcessorFast.from_pretrained("facebook/detr-resnet-50")
39+
images_processed = processor(images, return_tensors="pt", device="cuda")
40+
```
41+
42+
Here are some speed comparisons between the base and fast image processors for the `DETR` and `RT-DETR` models, and how they impact overall inference time:
43+
44+
<div class="flex">
45+
<div>
46+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_padded.png" />
47+
</div>
48+
<div>
49+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_batched_compiled.png" />
50+
</div>
51+
</div>
52+
53+
<div class="flex">
54+
<div>
55+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_single.png" />
56+
</div>
57+
<div>
58+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_batched.png" />
59+
</div>
60+
</div>
61+
62+
These benchmarks were run on an [AWS EC2 g5.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g5/), utilizing an NVIDIA A10G Tensor Core GPU.
63+
2164

2265
## ImageProcessingMixin
2366

docs/source/en/model_doc/rt_detr.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Initially, an image is processed using a pre-trained convolutional neural networ
4646
>>> from PIL import Image
4747
>>> from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
4848

49-
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
49+
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
5050
>>> image = Image.open(requests.get(url, stream=True).raw)
5151

5252
>>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
@@ -95,6 +95,12 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
9595
- preprocess
9696
- post_process_object_detection
9797

98+
## RTDetrImageProcessorFast
99+
100+
[[autodoc]] RTDetrImageProcessorFast
101+
- preprocess
102+
- post_process_object_detection
103+
98104
## RTDetrModel
99105

100106
[[autodoc]] RTDetrModel

src/transformers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@
12281228
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
12291229
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
12301230
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
1231-
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
1231+
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor", "RTDetrImageProcessorFast"])
12321232
_import_structure["models.sam"].extend(["SamImageProcessor"])
12331233
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
12341234
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
@@ -6152,7 +6152,7 @@
61526152
)
61536153
from .models.pvt import PvtImageProcessor
61546154
from .models.qwen2_vl import Qwen2VLImageProcessor
6155-
from .models.rt_detr import RTDetrImageProcessor
6155+
from .models.rt_detr import RTDetrImageProcessor, RTDetrImageProcessorFast
61566156
from .models.sam import SamImageProcessor
61576157
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
61586158
from .models.seggpt import SegGptImageProcessor

src/transformers/image_processing_utils_fast.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515

1616
import functools
1717
from dataclasses import dataclass
18+
from typing import Any, Iterable, List, Optional, Tuple
1819

1920
from .image_processing_utils import BaseImageProcessor
20-
from .utils.import_utils import is_torchvision_available
21+
from .utils.import_utils import is_torch_available, is_torchvision_available
2122

2223

2324
if is_torchvision_available():
2425
from torchvision.transforms import Compose
2526

27+
if is_torch_available():
28+
import torch
29+
2630

2731
@dataclass(frozen=True)
2832
class SizeDict:
@@ -66,3 +70,64 @@ def to_dict(self):
6670
encoder_dict = super().to_dict()
6771
encoder_dict.pop("_transform_params", None)
6872
return encoder_dict
73+
74+
75+
def get_image_size_for_max_height_width(
76+
image_size: Tuple[int, int],
77+
max_height: int,
78+
max_width: int,
79+
) -> Tuple[int, int]:
80+
"""
81+
Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
82+
Important, even if image_height < max_height and image_width < max_width, the image will be resized
83+
to at least one of the edges be equal to max_height or max_width.
84+
85+
For example:
86+
- input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
87+
- input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
88+
89+
Args:
90+
image_size (`Tuple[int, int]`):
91+
The image to resize.
92+
max_height (`int`):
93+
The maximum allowed height.
94+
max_width (`int`):
95+
The maximum allowed width.
96+
"""
97+
height, width = image_size
98+
height_scale = max_height / height
99+
width_scale = max_width / width
100+
min_scale = min(height_scale, width_scale)
101+
new_height = int(height * min_scale)
102+
new_width = int(width * min_scale)
103+
return new_height, new_width
104+
105+
106+
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
107+
"""
108+
Squeezes a tensor, but only if the axis specified has dim 1.
109+
"""
110+
if axis is None:
111+
return tensor.squeeze()
112+
113+
try:
114+
return tensor.squeeze(axis=axis)
115+
except ValueError:
116+
return tensor
117+
118+
119+
def max_across_indices(values: Iterable[Any]) -> List[Any]:
120+
"""
121+
Return the maximum value across all indices of an iterable of values.
122+
"""
123+
return [max(values_i) for values_i in zip(*values)]
124+
125+
126+
def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]:
127+
"""
128+
Get the maximum height and width across all images in a batch.
129+
"""
130+
131+
_, max_height, max_width = max_across_indices([img.shape for img in images])
132+
133+
return (max_height, max_width)

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123
("qwen2_vl", ("Qwen2VLImageProcessor",)),
124124
("regnet", ("ConvNextImageProcessor",)),
125125
("resnet", ("ConvNextImageProcessor",)),
126-
("rt_detr", "RTDetrImageProcessor"),
126+
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
127127
("sam", ("SamImageProcessor",)),
128128
("segformer", ("SegformerImageProcessor",)),
129129
("seggpt", ("SegGptImageProcessor",)),

0 commit comments

Comments
 (0)