diff --git a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py index 7027a0b77a3c..3be6613e0c79 100644 --- a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py +++ b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py @@ -16,26 +16,16 @@ import unittest import requests +import torch from datasets import load_dataset +from PIL import Image +from transformers import MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs -if is_torch_available(): - import torch - -if is_vision_available(): - from PIL import Image - - from transformers import MobileNetV2ImageProcessor - - if is_torchvision_available(): - from transformers import MobileNetV2ImageProcessorFast - - class MobileNetV2ImageProcessingTester: def __init__( self, @@ -103,9 +93,6 @@ def prepare_semantic_batch_inputs(): @require_torch @require_vision class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): - image_processing_class = MobileNetV2ImageProcessor if is_vision_available() else None - fast_image_processing_class = MobileNetV2ImageProcessorFast if is_torchvision_available() else None - def setUp(self): super().setUp() self.image_processor_tester = MobileNetV2ImageProcessingTester(self) @@ -115,7 +102,7 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]: image_processor = image_processing_class(**self.image_processor_dict) self.assertTrue(hasattr(image_processor, "do_resize")) self.assertTrue(hasattr(image_processor, "size")) @@ -124,7 +111,7 @@ def test_image_processor_properties(self): self.assertTrue(hasattr(image_processor, "do_reduce_labels")) def test_image_processor_from_dict_with_kwargs(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]: image_processor = image_processing_class.from_dict(self.image_processor_dict) self.assertEqual(image_processor.size, {"shortest_edge": 20}) self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) @@ -139,7 +126,7 @@ def test_image_processor_from_dict_with_kwargs(self): def test_call_segmentation_maps(self): # Initialize image_processing - for image_processing_class in self.image_processor_list: + for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]: image_processing = image_processing_class(**self.image_processor_dict) # create random PyTorch tensors image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) @@ -246,7 +233,7 @@ def test_call_segmentation_maps(self): def test_reduce_labels(self): # Initialize image_processing - for image_processing_class in self.image_processor_list: + for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]: image_processing = image_processing_class(**self.image_processor_dict) # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 @@ -260,19 +247,12 @@ def test_reduce_labels(self): self.assertTrue(encoding["labels"].min().item() >= 0) self.assertTrue(encoding["labels"].max().item() <= 255) - def test_slow_fast_equivalence(self): - if not self.test_slow_image_processor or not self.test_fast_image_processor: - self.skipTest(reason="Skipping slow/fast equivalence test") - - if self.image_processing_class is None or self.fast_image_processing_class is None: - self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") - # Test with single image dummy_image = Image.open( requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw ) - image_processor_slow = self.image_processing_class(**self.image_processor_dict) - image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + image_processor_slow = MobileNetV2ImageProcessor(**self.image_processor_dict) + image_processor_fast = MobileNetV2ImageProcessorFast(**self.image_processor_dict) encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")