|
20 | 20 |
|
21 | 21 | from transformers import Dinov2Config, ZoeDepthConfig |
22 | 22 | from transformers.file_utils import is_torch_available, is_vision_available |
| 23 | +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 |
23 | 24 | from transformers.testing_utils import require_torch, require_vision, slow, torch_device |
24 | 25 |
|
25 | 26 | from ...test_configuration_common import ConfigTester |
@@ -354,3 +355,29 @@ def test_inference_depth_estimation_post_processing_pad_flip(self): |
354 | 355 | model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device) |
355 | 356 |
|
356 | 357 | self.check_post_processing_test(image_processor, images, model, pad_input=True, flip_aug=True) |
| 358 | + |
| 359 | + def test_export(self): |
| 360 | + self.skipTest( |
| 361 | + reason="This test fails because the beit backbone of ZoeDepth is not compatible with torch.export" |
| 362 | + ) |
| 363 | + for strict in [True, False]: |
| 364 | + with self.subTest(strict=True): |
| 365 | + if not is_torch_greater_or_equal_than_2_4: |
| 366 | + self.skipTest(reason="This test requires torch >= 2.4 to run.") |
| 367 | + model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu").to(torch_device).eval() |
| 368 | + image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu") |
| 369 | + image = prepare_img() |
| 370 | + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) |
| 371 | + |
| 372 | + exported_program = torch.export.export( |
| 373 | + model, |
| 374 | + args=(inputs["pixel_values"],), |
| 375 | + strict=strict, |
| 376 | + ) |
| 377 | + with torch.no_grad(): |
| 378 | + eager_outputs = model(**inputs) |
| 379 | + exported_outputs = exported_program.module().forward(inputs["pixel_values"]) |
| 380 | + self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape) |
| 381 | + self.assertTrue( |
| 382 | + torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4) |
| 383 | + ) |
0 commit comments