Skip to content

Commit 55ec9ac

Browse files
Fix DinoV2 crash when batch_size > 1. (#429)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent 8365426 commit 55ec9ac

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

lib/bumblebee/vision/dino_v2.ex

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,22 +312,26 @@ defmodule Bumblebee.Vision.DinoV2 do
312312
Axon.layer(
313313
fn position_embeddings, pixel_values, _opts ->
314314
original_positions = div(spec.image_size, spec.patch_size)
315-
{batch_size, height, width, _channels} = Nx.shape(pixel_values)
315+
{_batch_size, height, width, _channels} = Nx.shape(pixel_values)
316316
resized_height = div(height, spec.patch_size)
317317
resized_width = div(width, spec.patch_size)
318+
position_embeddings_batch_size = Nx.axis_size(position_embeddings, 0)
318319

319320
class_position_embedding = position_embeddings[[.., 0..0//1, ..]]
320321
input_position_embeddings = position_embeddings[[.., 1..-1//1, ..]]
321322

322323
interpolated_position_embeddings =
323324
input_position_embeddings
324-
|> Nx.reshape({batch_size, original_positions, original_positions, spec.hidden_size})
325+
|> Nx.reshape(
326+
{position_embeddings_batch_size, original_positions, original_positions,
327+
spec.hidden_size}
328+
)
325329
|> Axon.Layers.resize(
326330
size: {resized_height, resized_width},
327331
method: :bicubic,
328332
antialias: false
329333
)
330-
|> Nx.reshape({batch_size, :auto, spec.hidden_size})
334+
|> Nx.reshape({position_embeddings_batch_size, :auto, spec.hidden_size})
331335

332336
Nx.concatenate([class_position_embedding, interpolated_position_embeddings], axis: 1)
333337
end,

test/bumblebee/vision/dino_v2_test.exs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,59 @@ defmodule Bumblebee.Vision.DinoV2Test do
6868
)
6969
end
7070

71+
test ":base with batch size > 1 and position embedding interpolation" do
72+
assert {:ok, %{model: model, params: params, spec: spec}} =
73+
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Model"})
74+
75+
assert %Bumblebee.Vision.DinoV2{architecture: :base} = spec
76+
77+
inputs = %{
78+
# simulate a batch of 2 different images
79+
"pixel_values" =>
80+
Nx.stack([
81+
Nx.broadcast(0.5, {64, 64, 3}),
82+
Nx.broadcast(-0.5, {64, 64, 3})
83+
])
84+
}
85+
86+
outputs = Axon.predict(model, params, inputs)
87+
88+
assert Nx.shape(outputs.hidden_state) == {2, 1025, 32}
89+
assert Nx.shape(outputs.pooled_state) == {2, 32}
90+
91+
assert_all_close(
92+
outputs.hidden_state[[0, 1..3, 1..3]],
93+
Nx.tensor([
94+
[-1.2287, -0.2291, -0.4323],
95+
[-1.1548, -0.4430, -0.4710],
96+
[-1.0547, -0.7580, -0.4654]
97+
]),
98+
atol: 1.0e-1
99+
)
100+
101+
assert_all_close(
102+
outputs.pooled_state[[0, 1..3]],
103+
Nx.tensor([-0.7270, -0.5913, 0.7701]),
104+
atol: 1.0e-2
105+
)
106+
107+
assert_all_close(
108+
outputs.hidden_state[[1, 1..3, 1..3]],
109+
Nx.tensor([
110+
[0.5043, 0.9529, 0.8042],
111+
[0.5555, 0.8786, 0.8492],
112+
[0.6114, 0.5701, 0.9278]
113+
]),
114+
atol: 1.0e-1
115+
)
116+
117+
assert_all_close(
118+
outputs.pooled_state[[1, 1..3]],
119+
Nx.tensor([0.1374, -0.9757, 1.1830]),
120+
atol: 1.0e-2
121+
)
122+
end
123+
71124
test ":base with swiglu ffn" do
72125
assert {:ok, %{model: model, params: params, spec: spec}} =
73126
Bumblebee.load_model(

0 commit comments

Comments
 (0)