
Commit 684e256

fix: Add prediction heads and tests for classification architectures
- Add prediction head (dense, activation, norm) to sequence classification
- Add prediction head to token classification
- Add attention_mask to input_template for sequence classification
- Add tests for sequence_classification and token_classification
1 parent 3db832e commit 684e256

2 files changed, +68 -4 lines changed
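
Context for the input_template change: the template is the map of dummy inputs Bumblebee traces when building and initializing a model, and the sequence classification graph mean-pools over the mask, so attention_mask has to be declared up front. A minimal inspection sketch, assuming the tiny-random checkpoint used in the tests below (the behaviour callback is a plain public function, so the new clause can be called directly; inspect output simplified):

{:ok, spec} =
  Bumblebee.load_spec(
    {:hf, "hf-internal-testing/tiny-random-ModernBertForSequenceClassification"}
  )

# With this commit the sequence classification template declares both inputs.
Bumblebee.Text.ModernBert.input_template(spec)
#=> %{"attention_mask" => u32 template {1, 1}, "input_ids" => u32 template {1, 1}}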

lib/bumblebee/text/modernbert.ex

Lines changed: 20 additions & 0 deletions
@@ -159,6 +159,13 @@ defmodule Bumblebee.Text.ModernBert do
   end

   @impl true
+  def input_template(%{architecture: :for_sequence_classification}) do
+    %{
+      "input_ids" => Nx.template({1, 1}, :u32),
+      "attention_mask" => Nx.template({1, 1}, :u32)
+    }
+  end
+
   def input_template(_spec) do
     %{"input_ids" => Nx.template({1, 1}, :u32)}
   end
@@ -193,6 +200,7 @@ defmodule Bumblebee.Text.ModernBert do
       outputs.hidden_state
       |> mean_pooling(inputs["attention_mask"])
       |> Axon.dense(spec.hidden_size,
+        use_bias: false,
         kernel_initializer: kernel_initializer(spec),
         name: "sequence_classification_head.dense"
       )
@@ -223,6 +231,16 @@ defmodule Bumblebee.Text.ModernBert do

     logits =
       outputs.hidden_state
+      |> Axon.dense(spec.hidden_size,
+        use_bias: false,
+        kernel_initializer: kernel_initializer(spec),
+        name: "token_classification_head.dense"
+      )
+      |> Layers.activation(spec.activation)
+      |> layer_norm(
+        epsilon: spec.layer_norm_epsilon,
+        name: "token_classification_head.norm"
+      )
       |> Axon.dropout(
         rate: classifier_dropout_rate(spec),
         name: "token_classification_head.dropout"
@@ -564,6 +582,8 @@ defmodule Bumblebee.Text.ModernBert do
       "sequence_classification_head.dense" => "head.dense",
      "sequence_classification_head.norm" => "head.norm",
      "sequence_classification_head.output" => "classifier",
+      "token_classification_head.dense" => "head.dense",
+      "token_classification_head.norm" => "head.norm",
      "token_classification_head.output" => "classifier"
    }
  end
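
One step this diff leans on but does not show is mean_pooling/2, which the sequence classification head pipes through before the new dense layer. Below is a tensor-level sketch of what masked mean pooling computes; it is an assumed equivalent for illustration, not the module's private helper itself:

defmodule PoolingSketch do
  import Nx.Defn

  # Average token embeddings, counting only positions where the attention
  # mask is 1, so padding does not dilute the pooled vector.
  defn masked_mean_pooling(hidden_state, attention_mask) do
    # {batch, seq} -> {batch, seq, 1}, cast to the hidden state's float type
    mask = attention_mask |> Nx.new_axis(-1) |> Nx.as_type(Nx.type(hidden_state))

    Nx.sum(hidden_state * mask, axes: [1]) / Nx.sum(mask, axes: [1])
  end
end

On the test inputs below, this reduces a {1, 10, hidden_size} hidden state to {1, hidden_size} using only the eight unmasked positions.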

test/bumblebee/text/modernbert_test.exs

Lines changed: 48 additions & 4 deletions
@@ -5,10 +5,6 @@ defmodule Bumblebee.Text.ModernBertTest do

   @moduletag model_test_tags()

-  # Note: sequence_classification and token_classification tests are skipped
-  # because the tiny-random test models have incompatible head structures.
-  # The architectures work correctly with production models.
-
   test ":base" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-ModernBertModel"})
@@ -54,4 +50,52 @@ defmodule Bumblebee.Text.ModernBertTest do
       ])
     )
   end
+
+  test ":for_sequence_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "hf-internal-testing/tiny-random-ModernBertForSequenceClassification"}
+             )
+
+    assert %Bumblebee.Text.ModernBert{architecture: :for_sequence_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 2}
+
+    assert_all_close(
+      outputs.logits,
+      Nx.tensor([[1.2857, 2.1079]])
+    )
+  end
+
+  test ":for_token_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "hf-internal-testing/tiny-random-ModernBertForTokenClassification"}
+             )
+
+    assert %Bumblebee.Text.ModernBert{architecture: :for_token_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 10, 2}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3, ..]],
+      Nx.tensor([
+        [[5.0522, -0.8999], [-3.2701, 1.8927], [-0.7372, 5.4871]]
+      ])
+    )
+  end
 end
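
With the heads and tests in place, the standard serving entry points should compose with these architectures. A hedged end-to-end sketch; the repo name is hypothetical, and it assumes a ModernBERT sequence classification checkpoint that also ships tokenizer files (the tiny-random test repos may not):

# Hypothetical checkpoint; substitute a real ModernBERT classifier.
repo = {:hf, "some-org/modernbert-sequence-classifier"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

serving = Bumblebee.Text.text_classification(model_info, tokenizer)
Nx.Serving.run(serving, "Bumblebee is great!")
#=> %{predictions: [%{label: ..., score: ...}, ...]}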
