Skip to content

Commit 812e3d5

Browse files
committed
add depth anything
Signed-off-by: Phillip Kuznetsov <[email protected]>
1 parent 661d5b7 commit 812e3d5

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

src/transformers/models/depth_anything/modeling_depth_anything.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,16 @@ def forward(self, hidden_states, size=None):
224224
hidden_states = hidden_states[::-1]
225225

226226
fused_hidden_states = []
227-
# first layer only uses the last hidden_state
228-
size = hidden_states[1].shape[2:]
229-
fused_hidden_state = self.layers[0](hidden_states[0], size=size)
230-
fused_hidden_states.append(fused_hidden_state)
227+
fused_hidden_state = None
231228

232-
# looping from the last layer to the second
233-
for idx, (hidden_state, layer) in enumerate(zip(hidden_states[1:], self.layers[1:])):
234-
size = hidden_states[1:][idx + 1].shape[2:] if idx != (len(hidden_states[1:]) - 1) else None
229+
for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
230+
size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
235231

236-
fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
232+
if fused_hidden_state is None:
233+
# first layer only uses the last hidden_state
234+
fused_hidden_state = layer(hidden_state, size=size)
235+
else:
236+
fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
237237

238238
fused_hidden_states.append(fused_hidden_state)
239239

tests/models/depth_anything/test_modeling_depth_anything.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from transformers import DepthAnythingConfig, Dinov2Config
2020
from transformers.file_utils import is_torch_available, is_vision_available
21+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
2122
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
2223

2324
from ...test_configuration_common import ConfigTester
@@ -290,3 +291,30 @@ def test_inference(self):
290291
).to(torch_device)
291292

292293
self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
294+
295+
def test_export(self):
296+
for strict in [True, False]:
297+
with self.subTest(strict=strict):
298+
if not is_torch_greater_or_equal_than_2_4:
299+
self.skipTest(reason="This test requires torch >= 2.4 to run.")
300+
model = (
301+
DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
302+
.to(torch_device)
303+
.eval()
304+
)
305+
image_processor = DPTImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
306+
image = prepare_img()
307+
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
308+
309+
exported_program = torch.export.export(
310+
model,
311+
args=(inputs["pixel_values"],),
312+
strict=strict,
313+
)
314+
with torch.no_grad():
315+
eager_outputs = model(**inputs)
316+
exported_outputs = exported_program.module().forward(inputs["pixel_values"])
317+
self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape)
318+
self.assertTrue(
319+
torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4)
320+
)

0 commit comments

Comments (0)