 
 from transformers import DepthAnythingConfig, Dinov2Config
 from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
@@ -290,3 +291,30 @@ def test_inference(self):
         ).to(torch_device)
 
         self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
+
+    def test_export(self):
+        for strict in [True, False]:
+            with self.subTest(strict=strict):
+                if not is_torch_greater_or_equal_than_2_4:
+                    self.skipTest(reason="This test requires torch >= 2.4 to run.")
+                model = (
+                    DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
+                    .to(torch_device)
+                    .eval()
+                )
+                image_processor = DPTImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
+                image = prepare_img()
+                inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+                exported_program = torch.export.export(
+                    model,
+                    args=(inputs["pixel_values"],),
+                    strict=strict,
+                )
+                with torch.no_grad():
+                    eager_outputs = model(**inputs)
+                    exported_outputs = exported_program.module().forward(inputs["pixel_values"])
+                self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape)
+                self.assertTrue(
+                    torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4)
+                )
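
For reference, a minimal standalone sketch of the flow this test exercises: exporting DepthAnythingForDepthEstimation with torch.export and comparing the exported program against eager mode. It assumes torch >= 2.4 and access to the LiheYoung/depth-anything-small-hf checkpoint; the sample image URL is only a placeholder input, not something taken from the diff.

# Sketch: export DepthAnything with torch.export and compare against eager outputs.
# Assumes torch >= 2.4; the image URL below is just a placeholder input.
import requests
import torch
from PIL import Image

from transformers import DepthAnythingForDepthEstimation, DPTImageProcessor

checkpoint = "LiheYoung/depth-anything-small-hf"
model = DepthAnythingForDepthEstimation.from_pretrained(checkpoint).eval()
image_processor = DPTImageProcessor.from_pretrained(checkpoint)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

# Trace the model into an ExportedProgram; strict=False relaxes the tracing checks.
exported_program = torch.export.export(model, args=(inputs["pixel_values"],), strict=False)

with torch.no_grad():
    eager_depth = model(**inputs).predicted_depth
    exported_depth = exported_program.module()(inputs["pixel_values"]).predicted_depth

# The exported graph should reproduce the eager depth map within a small tolerance.
print(torch.allclose(eager_depth, exported_depth, atol=1e-4))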