@@ -54,6 +54,7 @@ class PreprocessConfig:
     tile_size: int = 224
     max_num_tiles: int = 4
     possible_resolutions = None
+    pad_max_tiles: bool = True


 class TestImageTransform(unittest.TestCase):
@@ -136,6 +137,17 @@ def prepare_inputs(
                 [1.0, 1.0],  # expected_tile_max
                 [0.0, 0.0],  # expected_tile_min
                 [1, 2],  # expected_aspect_ratio
+                False,  # pad_max_tiles
+            ),
+            (
+                (100, 400, 3),  # image_size
+                torch.Size([4, 3, 224, 224]),  # expected shape
+                False,  # resize_to_max_canvas
+                [0.2230, 0.1763, 0.0, 0.0],  # expected_tile_means
+                [1.0, 1.0, 0.0, 0.0],  # expected_tile_max
+                [0.0, 0.0, 0.0, 0.0],  # expected_tile_min
+                [1, 2],  # expected_aspect_ratio
+                True,  # pad_max_tiles
             ),
             (
                 (1000, 300, 3),  # image_size
@@ -145,6 +157,7 @@ def prepare_inputs(
                 [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
                 [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
                 [4, 1],  # expected_aspect_ratio
+                False,  # pad_max_tiles
             ),
             (
                 (200, 200, 3),  # image_size
@@ -154,6 +167,7 @@ def prepare_inputs(
                 [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
                 [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
                 [2, 2],  # expected_aspect_ratio
+                False,  # pad_max_tiles
             ),
             (
                 (600, 200, 3),  # image_size
@@ -163,6 +177,17 @@ def prepare_inputs(
                 [1.0, 1.0, 1.0],  # expected_tile_max
                 [0.0, 0.0, 0.0],  # expected_tile_min
                 [3, 1],  # expected_aspect_ratio
+                False,  # pad_max_tiles
+            ),
+            (
+                (600, 200, 3),  # image_size
+                torch.Size([4, 3, 224, 224]),  # expected shape
+                False,  # resize_to_max_canvas
+                [0.4472, 0.4468, 0.3031, 0.0],  # expected_tile_means
+                [1.0, 1.0, 1.0, 0.0],  # expected_tile_max
+                [0.0, 0.0, 0.0, 0.0],  # expected_tile_min
+                [3, 1],  # expected_aspect_ratio
+                True,  # pad_max_tiles
             ),
         ]
     )
@@ -175,8 +200,11 @@ def test_preprocess(
         expected_tile_max: List[float],
         expected_tile_min: List[float],
         expected_ar: List[int],
+        pad_max_tiles: bool,
     ) -> None:
-        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
+        config = PreprocessConfig(
+            resize_to_max_canvas=resize_to_max_canvas, pad_max_tiles=pad_max_tiles
+        )

         reference_model = CLIPImageTransform(
             image_mean=config.image_mean,
@@ -187,6 +215,7 @@ def test_preprocess(
             tile_size=config.tile_size,
             max_num_tiles=config.max_num_tiles,
             possible_resolutions=None,
+            pad_max_tiles=config.pad_max_tiles,
         )

         eager_model = _CLIPImageTransform(
@@ -196,6 +225,7 @@ def test_preprocess(
             antialias=config.antialias,
             tile_size=config.tile_size,
             max_num_tiles=config.max_num_tiles,
+            pad_max_tiles=config.pad_max_tiles,
         )

         exported_model = export_preprocess(
@@ -205,6 +235,7 @@ def test_preprocess(
             antialias=config.antialias,
             tile_size=config.tile_size,
             max_num_tiles=config.max_num_tiles,
+            pad_max_tiles=config.pad_max_tiles,
         )

         executorch_model = lower_to_executorch_preprocess(exported_model)
@@ -244,8 +275,11 @@ def test_preprocess(
             self.assertAlmostEqual(tile.min().item(), expected_tile_min[i], delta=1e-4)

         # Check num tiles matches the product of the aspect ratio.
-        expected_num_tiles = reference_ar[0] * reference_ar[1]
-        self.assertEqual(expected_num_tiles, reference_image.shape[0])
+        if pad_max_tiles:
+            self.assertEqual(config.max_num_tiles, reference_image.shape[0])
+        else:
+            expected_num_tiles = reference_ar[0] * reference_ar[1]
+            self.assertEqual(expected_num_tiles, reference_image.shape[0])

         # Pre-work for eager and exported models. The reference model performs these
         # calculations and passes the result to _CLIPImageTransform, the exportable model.
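For context on what the new test cases assert: with `pad_max_tiles=True` the transform zero-pads the tile stack so the output always has `max_num_tiles` tiles, which is why the trailing tiles in the new expectations have mean/max/min of 0.0 and the expected shape is always `[4, 3, 224, 224]`. A minimal sketch of that padding step (the helper name and standalone logic are illustrative, not the actual `CLIPImageTransform` implementation):

```python
import torch


def pad_tiles_to_max(tiles: torch.Tensor, max_num_tiles: int) -> torch.Tensor:
    """Zero-pad a [n, c, h, w] tile stack to [max_num_tiles, c, h, w].

    Illustrative helper only; the real padding happens inside the
    CLIPImageTransform / _CLIPImageTransform pipeline under test.
    """
    n, c, h, w = tiles.shape
    if n >= max_num_tiles:
        return tiles
    padding = torch.zeros(max_num_tiles - n, c, h, w, dtype=tiles.dtype)
    return torch.cat([tiles, padding], dim=0)


# A 3-tile image (aspect ratio [3, 1]) padded to 4 tiles, matching the
# expected shape torch.Size([4, 3, 224, 224]) in the new test cases.
tiles = torch.rand(3, 3, 224, 224)
padded = pad_tiles_to_max(tiles, max_num_tiles=4)
assert padded.shape == torch.Size([4, 3, 224, 224])
assert padded[3].max().item() == 0.0  # trailing tile is all zeros
```

Padding to a fixed tile count gives the exported model a static output shape, which is what makes the `export_preprocess` / ExecuTorch paths in this test tractable.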