#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-import unittest
-
from typing import Any, Dict, List, Tuple

import numpy as np
import PIL
+import pytest
import torch

# Import these first. Otherwise, the custom ops are not registered.
from torchtune.modules.transforms.vision_utils.get_inscribed_size import (
    get_inscribed_size,
)
+
from torchvision.transforms.v2 import functional as F


-class TestImageTransform(unittest.TestCase):
+def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
+    config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
+
+    reference_model = CLIPImageTransform(
+        image_mean=config.image_mean,
+        image_std=config.image_std,
+        resample=config.resample,
+        antialias=config.antialias,
+        tile_size=config.tile_size,
+        max_num_tiles=config.max_num_tiles,
+        resize_to_max_canvas=config.resize_to_max_canvas,
+        possible_resolutions=None,
+    )
+
+    model = CLIPImageTransformModel(config)
+
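+    # torch.export captures the eager transform as a graph ahead of time:
+    # strict=False selects non-strict (Python-interpreter) tracing, and
+    # dynamic_shapes marks the dims from model.get_dynamic_shapes() as symbolic
+    # so the exported graph accepts images of varying size.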
+    exported_model = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
+    )
+
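+    # The AOTInductor path is left disabled; the matching aoti checks at the
+    # bottom of run_preprocess are commented out as well.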
+    # aoti_path = torch._inductor.aot_compile(
+    #     exported_model.module(),
+    #     model.get_example_inputs(),
+    # )
+
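+    # Lower the exported program for ExecuTorch: to_edge converts it to the
+    # Edge dialect and to_executorch serializes it into a loadable program.
+    # IR validity checking is skipped, presumably because of the custom ops
+    # registered by the imports above.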
+    edge_program = to_edge(
+        exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
+    )
+    executorch_model = edge_program.to_executorch()
+
+    return {
+        "config": config,
+        "reference_model": reference_model,
+        "model": model,
+        "exported_model": exported_model,
+        # "aoti_path": aoti_path,
+        "executorch_model": executorch_model,
+    }
+
+
+# From https://github.com/pytorch/torchtune/blob/main/tests/test_utils.py#L231
+def assert_expected(
+    actual: Any,
+    expected: Any,
+    rtol: float = 1e-5,
+    atol: float = 1e-8,
+    check_device: bool = True,
+):
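+    # Thin wrapper around torch.testing.assert_close with an explicit failure
+    # message, vendored from torchtune's test utils (see link above).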
+    torch.testing.assert_close(
+        actual,
+        expected,
+        rtol=rtol,
+        atol=atol,
+        check_device=check_device,
+        msg=f"actual: {actual}, expected: {expected}",
+    )
+
+
+class TestImageTransform:
    """
-    This unittest checks that the exported image transform model produces the
+    This test checks that the exported image transform model produces the
    same output as the reference model.

    Reference model: CLIPImageTransform
@@ -52,59 +113,11 @@ class TestImageTransform(unittest.TestCase):
    https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
    """

-    @staticmethod
-    def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
-        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
-
-        reference_model = CLIPImageTransform(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-            resize_to_max_canvas=config.resize_to_max_canvas,
-            possible_resolutions=None,
-        )
-
-        model = CLIPImageTransformModel(config)
-
-        exported_model = torch.export.export(
-            model.get_eager_model(),
-            model.get_example_inputs(),
-            dynamic_shapes=model.get_dynamic_shapes(),
-            strict=False,
-        )
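+    # Evaluated once when the class body is defined, so both model variants are
+    # built a single time and shared across tests (the pytest replacement for
+    # unittest's setUpClass).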
+    models_no_resize = initialize_models(resize_to_max_canvas=False)
+    models_resize = initialize_models(resize_to_max_canvas=True)

-        # aoti_path = torch._inductor.aot_compile(
-        #     exported_model.module(),
-        #     model.get_example_inputs(),
-        # )
-
-        edge_program = to_edge(
-            exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
-        )
-        executorch_model = edge_program.to_executorch()
-
-        return {
-            "config": config,
-            "reference_model": reference_model,
-            "model": model,
-            "exported_model": exported_model,
-            # "aoti_path": aoti_path,
-            "executorch_model": executorch_model,
-        }
-
-    @classmethod
-    def setUpClass(cls):
-        cls.models_no_resize = TestImageTransform.initialize_models(
-            resize_to_max_canvas=False
-        )
-        cls.models_resize = TestImageTransform.initialize_models(
-            resize_to_max_canvas=True
-        )
-
-    def setUp(self):
+    @pytest.fixture(autouse=True)
+    def setup_function(self):
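+        # autouse: runs before every test, reseeding numpy so the randomly
+        # generated test images are reproducible.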
        np.random.seed(0)

    def prepare_inputs(
@@ -185,23 +198,32 @@ def run_preprocess(
        reference_ar = reference_output["aspect_ratio"].tolist()

        # Check that the output shape and aspect ratio match the expected values.
-        self.assertEqual(reference_image.shape, expected_shape)
-        self.assertEqual(reference_ar, expected_ar)
+        assert (
+            reference_image.shape == expected_shape
+        ), f"Expected shape {expected_shape} but got {reference_image.shape}"
+
+        assert (
+            reference_ar == expected_ar
+        ), f"Expected ar {expected_ar} but got {reference_ar}"

        # Check that pixel values are within the expected range [0, 1].
-        self.assertTrue(0 <= reference_image.min() <= reference_image.max() <= 1)
+        assert (
+            0 <= reference_image.min() <= reference_image.max() <= 1
+        ), f"Expected pixel values in range [0, 1] but got {reference_image.min()} to {reference_image.max()}"

        # Check that the mean, max, and min of each tile match the expected values.
        for i, tile in enumerate(reference_image):
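+            # rtol=0 makes these purely absolute-tolerance checks, mirroring the
+            # delta semantics of the assertAlmostEqual calls they replace.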
-            self.assertAlmostEqual(
-                tile.mean().item(), expected_tile_means[i], delta=1e-4
+            assert_expected(
+                tile.mean().item(), expected_tile_means[i], rtol=0, atol=1e-4
            )
-            self.assertAlmostEqual(tile.max().item(), expected_tile_max[i], delta=1e-4)
-            self.assertAlmostEqual(tile.min().item(), expected_tile_min[i], delta=1e-4)
+            assert_expected(tile.max().item(), expected_tile_max[i], rtol=0, atol=1e-4)
+            assert_expected(tile.min().item(), expected_tile_min[i], rtol=0, atol=1e-4)

        # Check that the number of tiles matches the product of the aspect ratio.
        expected_num_tiles = reference_ar[0] * reference_ar[1]
-        self.assertEqual(expected_num_tiles, reference_image.shape[0])
+        assert (
+            expected_num_tiles == reference_image.shape[0]
+        ), f"Expected {expected_num_tiles} tiles but got {reference_image.shape[0]}"

        # Pre-work for the eager and exported models. The reference model performs these
        # calculations and passes the result to _CLIPImageTransform, the exportable model.
@@ -215,26 +237,32 @@ def run_preprocess(
            image_tensor, inscribed_size, best_resolution
        )
        eager_ar = eager_ar.tolist()
-        self.assertTrue(torch.allclose(reference_image, eager_image))
-        self.assertEqual(reference_ar, eager_ar)
+        assert torch.allclose(reference_image, eager_image)
+        assert (
+            reference_ar == eager_ar
+        ), f"Eager model: expected {reference_ar} but got {eager_ar}"

        # Run the exported model and check that it matches the reference model.
        exported_model = models["exported_model"]
        exported_image, exported_ar = exported_model.module()(
            image_tensor, inscribed_size, best_resolution
        )
        exported_ar = exported_ar.tolist()
-        self.assertTrue(torch.allclose(reference_image, exported_image))
-        self.assertEqual(reference_ar, exported_ar)
+        assert torch.allclose(reference_image, exported_image)
+        assert (
+            reference_ar == exported_ar
+        ), f"Exported model: expected {reference_ar} but got {exported_ar}"

        # Run the ExecuTorch model and check that it matches the reference model.
        executorch_model = models["executorch_model"]
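+        # Load the serialized program straight from the in-memory buffer; no
+        # .pte file is written to disk.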
        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
        et_image, et_ar = executorch_module.forward(
            (image_tensor, inscribed_size, best_resolution)
        )
-        self.assertTrue(torch.allclose(reference_image, et_image))
-        self.assertEqual(reference_ar, et_ar.tolist())
+        assert torch.allclose(reference_image, et_image)
+        assert (
+            reference_ar == et_ar.tolist()
+        ), f"Executorch model: expected {reference_ar} but got {et_ar.tolist()}"

        # Run the aoti model and check that it matches the reference model.
        # aoti_path = models["aoti_path"]