Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 9 additions & 29 deletions tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,16 @@
import unittest

import requests
import torch
from datasets import load_dataset
from PIL import Image

from transformers import MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
import torch

if is_vision_available():
from PIL import Image

from transformers import MobileNetV2ImageProcessor

if is_torchvision_available():
from transformers import MobileNetV2ImageProcessorFast


class MobileNetV2ImageProcessingTester:
def __init__(
self,
Expand Down Expand Up @@ -103,9 +93,6 @@ def prepare_semantic_batch_inputs():
@require_torch
@require_vision
class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = MobileNetV2ImageProcessor if is_vision_available() else None
fast_image_processing_class = MobileNetV2ImageProcessorFast if is_torchvision_available() else None

def setUp(self):
super().setUp()
self.image_processor_tester = MobileNetV2ImageProcessingTester(self)
Expand All @@ -115,7 +102,7 @@ def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()

def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
Expand All @@ -124,7 +111,7 @@ def test_image_processor_properties(self):
self.assertTrue(hasattr(image_processor, "do_reduce_labels"))

def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
Expand All @@ -139,7 +126,7 @@ def test_image_processor_from_dict_with_kwargs(self):

def test_call_segmentation_maps(self):
# Initialize image_processing
for image_processing_class in self.image_processor_list:
for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]:
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
Expand Down Expand Up @@ -246,7 +233,7 @@ def test_call_segmentation_maps(self):

def test_reduce_labels(self):
# Initialize image_processing
for image_processing_class in self.image_processor_list:
for image_processing_class in [MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast]:
image_processing = image_processing_class(**self.image_processor_dict)

# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
Expand All @@ -260,19 +247,12 @@ def test_reduce_labels(self):
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)

def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")

if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

# Test with single image
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
image_processor_slow = MobileNetV2ImageProcessor(**self.image_processor_dict)
image_processor_fast = MobileNetV2ImageProcessorFast(**self.image_processor_dict)

encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
Expand Down