Draft
36 commits
fc0825a
[#SAMPLE-6] Add state to logits processing
joelpaulkoch Oct 17, 2025
01ab3af
stateful logits processors
joelpaulkoch Oct 16, 2025
5413662
adding another test
joelpaulkoch Oct 16, 2025
9d4ef39
fix test so compilation works
joelpaulkoch Oct 20, 2025
4ce01cc
demonstrate stateful logits processor through test assertions
joelpaulkoch Oct 20, 2025
2161b77
independent state for batch entries
joelpaulkoch Oct 20, 2025
fefc9fd
renamed initial_suppressed_token_index for clarity
xhr15 Oct 21, 2025
6e8612a
renamed next_suppressed_index -> :next_suppressed_token_index
xhr15 Oct 21, 2025
e43254a
logits_processor_states -> logits_processor_state in batch tests
xhr15 Oct 21, 2025
a2f0015
added a test with batch size 1 for clarity
xhr15 Oct 21, 2025
0cdc0ad
renaming suppressed_id -> suppressed_token_id
xhr15 Oct 21, 2025
cc6d6e3
more comments
xhr15 Oct 21, 2025
3816e7c
changed to demonstrate stack functionality
xhr15 Oct 23, 2025
fe58712
merged tests
xhr15 Oct 23, 2025
c97890a
removed test for processor only used in test
xhr15 Oct 23, 2025
fbf5ef3
introduces LogitsProcessor module
xhr15 Oct 24, 2025
dfa223c
clean up
joelpaulkoch Oct 27, 2025
9098bda
mix format
joelpaulkoch Oct 27, 2025
544d80f
vectorized sequences are called sequence in context
joelpaulkoch Oct 27, 2025
2ba5e0a
don't vectorize all the logits processor state
joelpaulkoch Oct 27, 2025
196c8f0
swap {logits, state} to {state, logits}
joelpaulkoch Nov 5, 2025
ee2a01e
rename logits_processor_state to logits_processor_states
joelpaulkoch Nov 5, 2025
3563ff0
states as tuples
joelpaulkoch Nov 5, 2025
6db771e
update test
joelpaulkoch Nov 5, 2025
c8442e0
single initial state for all batch entries
joelpaulkoch Nov 5, 2025
41dd2ad
vectorize sequence for init, derive vectorized state
joelpaulkoch Nov 5, 2025
311f77c
switch to EXLA as the evaluator is lacking vectorisation support:
xhr15 Nov 14, 2025
ec92264
Apply suggestion from @jonatanklosko
xhr15 Nov 14, 2025
201e103
removed comments
xhr15 Nov 14, 2025
70d7f65
slimmed down comments more
xhr15 Nov 14, 2025
ce92584
introduced types for init_context and process_context
xhr15 Nov 14, 2025
6d8f494
don't vectorize initial_enforced_token_id in test as it's the same ov…
xhr15 Nov 14, 2025
578ce11
bonus track: two more livebooks concerning logits processing. Not str…
xhr15 Nov 14, 2025
1f7798f
Update test/bumblebee/text/generation/logits_processing_test.exs
xhr15 Nov 17, 2025
e9d0c78
moving livebooks to separate PR
xhr15 Nov 17, 2025
1da730e
Update logits_processor.ex
xhr15 Nov 18, 2025
46 changes: 31 additions & 15 deletions lib/bumblebee/text/generation.ex
@@ -387,8 +387,12 @@ defmodule Bumblebee.Text.Generation do
end ++ logits_processors

fn logits, context ->
for processor <- processors, processor, reduce: logits do
logits -> processor.(logits, context)
for processor <- processors, processor, reduce: {logits, context} do
{logits, context} ->
case processor.(logits, context) do
{logits, new_context} -> {logits, new_context}
logits -> {logits, context}
end
end
end
end
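
Note on the contract this hunk introduces: a logits processor may stay stateless and return bare logits, or it may return a {logits, context} tuple so that updated state reaches the following processors and the next decoding step. A minimal sketch of the two shapes, assuming the :logits_processor_states key as it exists in this revision of the PR (the processor bodies are illustrative, not code from the PR):

    # Stateless processor: returns only logits, so the reduce clause above
    # keeps the incoming context unchanged.
    stateless = fn logits, _context ->
      Nx.multiply(logits, 2)
    end

    # Stateful processor: returns {logits, context}; the updated context is
    # handed to the next processor and carried into the next decoding step.
    stateful = fn logits, context ->
      step = context.logits_processor_states[:step] || Nx.tensor(0)
      context = put_in(context, [:logits_processor_states, :step], Nx.add(step, 1))
      {logits, context}
    end
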
@@ -551,7 +555,8 @@ defmodule Bumblebee.Text.Generation do
length: length,
finished_length: finished_length,
# The ignored return value that we attach all hooks to
ignored: Nx.broadcast(0, {batch_size})
ignored: Nx.broadcast(0, {batch_size}),
logits_processor_states: %{}
}
end

@@ -574,7 +579,7 @@ defmodule Bumblebee.Text.Generation do
outputs = predict_fun.(params, inputs)

logits = outputs.logits[[.., -1]]
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)
token_id = Nx.argmax(logits, axis: -1)

state = update_sequences(state, token_id, pad_token_id, eos_token_id)
@@ -632,14 +637,25 @@ defmodule Bumblebee.Text.Generation do
end

defnp batch_process_logits(logits_processor_fun, logits, state) do
logits
|> Nx.vectorize(:batch)
|> logits_processor_fun.(%{
sequence: Nx.vectorize(state.sequences, :batch),
length: state.length,
input_length: state.input_length
})
|> Nx.devectorize(keep_names: false)
logits = Nx.vectorize(logits, :batch)

{logits, new_context} =
logits_processor_fun.(logits, %{
sequence: Nx.vectorize(state.sequences, :batch),
length: state.length,
input_length: state.input_length,
# logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch)
logits_processor_states: state.logits_processor_states
})

logits = Nx.devectorize(logits, keep_names: false)

logits_processor_states =
Nx.devectorize(new_context.logits_processor_states, keep_names: false)

sequences = Nx.devectorize(new_context.sequence, keep_names: false)

{logits, %{state | sequences: sequences, logits_processor_states: logits_processor_states}}
end

# Contrastive search
@@ -684,7 +700,7 @@ defmodule Bumblebee.Text.Generation do
joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)

logits = outputs.logits[[.., -1]]
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)
scores = Axon.Activations.softmax(logits, axis: -1)
{top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)

@@ -727,7 +743,7 @@ defmodule Bumblebee.Text.Generation do

logits = outputs.logits[[.., -1]]
logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)

scores = Axon.Activations.softmax(logits, axis: -1)
{top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
@@ -888,7 +904,7 @@ defmodule Bumblebee.Text.Generation do
outputs = predict_fun.(params, inputs)

logits = outputs.logits[[.., -1]]
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)
scores = Axon.Activations.softmax(logits)
token_id = batched_choice(key, scores)

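
Note on the batch_process_logits change above: the logits and the sequences take an Nx.vectorize/Nx.devectorize round-trip, so each processor sees a single batch entry at a time. A toy sketch of that round-trip, assuming an Nx version with Nx.vectorize/2 support (shapes and values are made up):

    logits = Nx.iota({2, 4}, type: :f32)       # {batch, vocab}

    # Vectorizing the leading axis lets a processor treat each batch entry
    # as a plain {vocab} tensor, without indexing into the batch dimension.
    vectorized = Nx.vectorize(logits, :batch)

    processed = Nx.multiply(vectorized, 2)

    # Devectorizing restores the explicit {2, 4} batch shape for the rest of
    # the decoding loop.
    Nx.devectorize(processed, keep_names: false)
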
50 changes: 49 additions & 1 deletion test/bumblebee/text/generation/logits_processing_test.exs
@@ -5,6 +5,53 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do

alias Bumblebee.Text.Generation.LogitsProcessing

describe "stateful logits processors" do
defmodule StatefulLogitsProcessing do
import Nx.Defn

deftransform stateful_processor(logits, context, opts) do
initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]])

suppressed_index =
context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index
[Review comment by the PR author]
Suggested change:
- context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index
+ context.logits_processor_state[:next_suppressed_index] || initial_suppressed_index

values =
Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index))

logits = Nx.indexed_put(logits, suppressed_index, values)

next_suppressed_index = Nx.add(suppressed_index, Nx.tensor([1]))

context =
put_in(
context,
[:logits_processor_states, :next_suppressed_index],
next_suppressed_index
)

{logits, context}
end
end

test "can register and modify state" do
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])

context = context([1, 0, 0, 0])

{logits, context} =
StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_index: 0)

assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0]))
assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([1]))

{logits, context} =
StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_index: 0)

assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0]))
assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([2]))
end
end

describe "suppressed_tokens_processor/3" do
test "ignores the given tokens" do
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
@@ -382,7 +429,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
%{
sequence: Nx.tensor(sequence),
length: Enum.count(sequence, &(&1 != 0)),
input_length: 1
input_length: 1,
logits_processor_states: %{}
[Review comment by the PR author]
Suggested change:
- logits_processor_states: %{}
+ logits_processor_state: %{}
The individual processor functions see only a single state, not all the states of the batch.

}
end
end
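
For comparison with the suppression example above, another processor could follow the same {logits, context} contract while keeping a different kind of per-step state. The sketch below is purely illustrative (the module, function, and option names decayed_temperature_processor, :temperature_scale, :initial_scale, and :decay are hypothetical, not part of Bumblebee): it divides the logits by a scale stored under :logits_processor_states and grows that scale after every call, flattening the distribution a little more on each decoding step.

    defmodule IllustrativeProcessors do
      import Nx.Defn

      # Hypothetical stateful processor: divides the logits by a scale kept in
      # the shared state map and grows the scale after each step.
      deftransform decayed_temperature_processor(logits, context, opts) do
        initial_scale = Nx.tensor(opts[:initial_scale] || 1.0)

        scale = context.logits_processor_states[:temperature_scale] || initial_scale

        logits = Nx.divide(logits, scale)

        next_scale = Nx.multiply(scale, opts[:decay] || 1.1)

        context =
          put_in(context, [:logits_processor_states, :temperature_scale], next_scale)

        {logits, context}
      end
    end

Such a processor would be registered the same way as the one in generation_test.exs below, through the :logits_processors option of build_generate.
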
57 changes: 57 additions & 0 deletions test/bumblebee/text/generation_test.exs
@@ -106,4 +106,61 @@ defmodule Bumblebee.Text.GenerationTest do

assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
end

test "with stateful logits processor" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

inputs = %{
"input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]),
"attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]),
"seed" => Nx.tensor([0])
}

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)

generate =
Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
logits_processors:
&Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
initial_suppressed_index: 0
)
)

%{token_ids: token_ids} = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[80, 80, 80]]))
end

defmodule StatefulLogitsProcessing do
import Nx.Defn

deftransform stateful_processor(logits, context, opts) do
initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]])

suppressed_index =
context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index

values =
Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index))

logits = Nx.indexed_put(logits, suppressed_index, values)

next_suppressed_index = Nx.add(suppressed_index, Nx.tensor([1]))

context =
put_in(
context,
[:logits_processor_states, :next_suppressed_index],
next_suppressed_index
)

{logits, context}
end
end
end