Skip to content

Commit a07b5e9

Browse files
MilkCloudsyonigozlanCyrilvallez
authored
feat: add is_fast to ImageProcessor (#39603)
* feat: add `is_fast` to ImageProcessor * test_image_processing_common.py 업데이트 Co-authored-by: Yoni Gozlan <[email protected]> * feat: add missing BaseImageProcessorFast import * fix: `issubclass` for discriminating subclass of BaseImageProcessorFast --------- Co-authored-by: Yoni Gozlan <[email protected]> Co-authored-by: Cyril Vallez <[email protected]>
1 parent 952fac1 commit a07b5e9

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

src/transformers/image_processing_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ class BaseImageProcessor(ImageProcessingMixin):
3939
def __init__(self, **kwargs):
4040
super().__init__(**kwargs)
4141

42+
@property
43+
def is_fast(self) -> bool:
44+
"""
45+
`bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
46+
"""
47+
return False
48+
4249
def __call__(self, images, **kwargs) -> BatchFeature:
4350
"""Preprocess an image or a batch of images."""
4451
return self.preprocess(images, **kwargs)

src/transformers/image_processing_utils_fast.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ def __init__(
235235
# get valid kwargs names
236236
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
237237

238+
@property
239+
def is_fast(self) -> bool:
240+
"""
241+
`bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
242+
"""
243+
return True
244+
238245
def resize(
239246
self,
240247
image: "torch.Tensor",

tests/test_image_processing_common.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
slow,
3737
torch_device,
3838
)
39-
from transformers.utils import is_torch_available, is_vision_available
39+
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
40+
41+
42+
if is_torchvision_available():
43+
from transformers.image_processing_utils_fast import BaseImageProcessorFast
4044

4145

4246
if is_torch_available():
@@ -241,6 +245,16 @@ def measure_time(image_processor, image):
241245

242246
self.assertLessEqual(fast_time, slow_time)
243247

248+
def test_is_fast(self):
249+
for image_processing_class in self.image_processor_list:
250+
image_processor = image_processing_class(**self.image_processor_dict)
251+
252+
# Check is_fast is set correctly
253+
if is_torchvision_available() and issubclass(image_processing_class, BaseImageProcessorFast):
254+
self.assertTrue(image_processor.is_fast)
255+
else:
256+
self.assertFalse(image_processor.is_fast)
257+
244258
def test_image_processor_to_json_string(self):
245259
for image_processing_class in self.image_processor_list:
246260
image_processor = image_processing_class(**self.image_processor_dict)

0 commit comments

Comments
 (0)