From acbed1ee2c1f99d13cb7655e3fb4844cfd39abae Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Mon, 20 Oct 2025 17:14:04 +0200
Subject: [PATCH 1/3] [#SAMPLE-7] DFA constrained sampling

https://bitcrowd.atlassian.net/browse/SAMPLE-7

From 63a298c4fc31fd433bab257e2904b53f057719e1 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Mon, 20 Oct 2025 17:21:39 +0200
Subject: [PATCH 2/3] Enable constrained generation with DFA

---
 lib/bumblebee/text/generation.ex              |   6 +
 .../text/generation/logits_processing.ex      | 143 ++++++++++++++++++
 lib/bumblebee/text/generation_config.ex       |   9 ++
 3 files changed, 158 insertions(+)

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 51ccb58a..a84ac81f 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -369,6 +369,12 @@ defmodule Bumblebee.Text.Generation do
       if config.forced_token_ids do
         &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
       end,
+      if config.allowed_token_ids != [] do
+        &allowed_tokens_processor(&1, &2, allowed_token_ids: config.allowed_token_ids)
+      end,
+      if config.dfa do
+        &dfa_processor(&1, &2, dfa: config.dfa)
+      end,
       if config.temperature && config.temperature != 1.0 do
         &temperature_processor(&1, &2, temperature: config.temperature)
       end
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index eff38e52..39be399c 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,133 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
 
   import Nx.Defn
 
+  deftransform dfa_processor(logits, context, opts \\ []) do
+    opts = Keyword.validate!(opts, [:dfa])
+    dfa = opts[:dfa]
+    dfa_mode = dfa[:mode]
+
+    # The highest state id occurring in any transition determines the number
+    # of rows of the transition tensor, so state ids should be numbered
+    # densely from 0.
+    last_state =
+      dfa.state_transitions
+      |> Enum.map(fn {state, _token_id, next_state} -> max(state, next_state) end)
+      |> Enum.max()
+
+    num_states = last_state + 1
+
+    # Encode the DFA as a {num_states, vocab_size} tensor in which entry
+    # [state, token_id] holds the next state. 0 means "token not allowed",
+    # so state 0 may only be used as the initial state, never as a
+    # transition target.
+    state_transitions_tensor =
+      for {current_state, token_id, next_state} <- dfa.state_transitions,
+          reduce: Nx.broadcast(0, {num_states, Nx.size(logits)}) do
+        tensor ->
+          Nx.indexed_put(tensor, Nx.tensor([current_state, token_id]), next_state)
+      end
+
+    initial_state = Nx.tensor([dfa.initial_state]) |> Nx.vectorize(:batch)
+
+    case dfa_mode do
+      :stateful ->
+        current_state =
+          if context.length == context.input_length do
+            initial_state
+          else
+            previous_state = context.logits_processor_state.dfa
+
+            current_state_from_last_state(
+              state_transitions_tensor,
+              context.sequence,
+              context.length,
+              previous_state
+            )
+          end
+
+        logits = suppress_logits(logits, state_transitions_tensor, current_state)
+
+        context = put_in(context, [:logits_processor_state, :dfa], current_state)
+
+        {logits, context}
+
+      :stateless ->
+        current_state =
+          if context.length == context.input_length do
+            initial_state
+          else
+            find_current_state(
+              initial_state,
+              state_transitions_tensor,
+              context.sequence,
+              context.input_length,
+              context.length
+            )
+          end
+
+        suppress_logits(logits, state_transitions_tensor, current_state)
+    end
+  end
+
+  defnp suppress_logits(logits, state_transitions_tensor, state) do
+    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
+    Nx.select(state_transitions_tensor[state], logits, suppressed_logits)
+  end
+
+  defnp current_state_from_last_state(
+          state_transitions_tensor,
+          sequence,
+          current_length,
+          last_state
+        ) do
+    last_token_id = sequence[current_length - 1]
+    state_transitions_tensor[[last_state, last_token_id]] |> Nx.squeeze()
+  end
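+
+  # Illustration of the encoding above (values invented for illustration):
+  # for a vocabulary of size 4 and the transitions [{0, 2, 1}, {1, 3, 1}],
+  # the transition tensor is
+  #
+  #   [[0, 0, 1, 0],
+  #    [0, 0, 0, 1]]
+  #
+  # so in state 0 only token 2 may be sampled (moving to state 1), and in
+  # state 1 only token 3 may be sampled (staying in state 1).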
+
+  defn find_current_state(
+         initial_state,
+         state_transitions_tensor,
+         sequence,
+         input_length,
+         current_length
+       ) do
+    generated_length = current_length - input_length
+
+    last_token_id = sequence[current_length - 1]
+    token_column = state_transitions_tensor[[.., last_token_id]] |> Nx.squeeze()
+
+    # top_k yields the two largest values of the column. If the token is
+    # unambiguous, there is exactly one non-zero entry, namely top_values[0].
+    # If top_values[1] is also non-zero, the column has two or more non-zero
+    # entries, so the last token alone does not identify the current state.
+    {top_values, _top_indices} = Nx.top_k(token_column, k: 2)
+
+    ambiguous_token? = top_values[[1]]
+
+    state =
+      cond do
+        ambiguous_token? ->
+          # Replay the generated suffix through the DFA from the initial state.
+          {state, _i, _sequence, _input_length, _generated_length, _state_transitions_tensor} =
+            while {state = initial_state, i = 0, sequence, input_length, generated_length,
+                   state_transitions_tensor},
+                  Nx.less(i, generated_length) do
+              chosen_token = sequence[input_length + i]
+              new_state = state_transitions_tensor[[state, chosen_token]]
+
+              {new_state, i + 1, sequence, input_length, generated_length,
+               state_transitions_tensor}
+            end
+
+          state
+
+        true ->
+          # top_values[0] is the state we moved to, as it is the only
+          # transition with next state != 0 for this token id.
+          top_values[[0]]
+      end
+
+    state
+  end
+
   deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
     opts = Keyword.validate!(opts, [:suppressed_token_ids])
 
@@ -11,6 +138,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     Nx.indexed_put(logits, indices, values)
   end
 
+  deftransform allowed_tokens_processor(logits, _context, opts \\ []) do
+    opts = Keyword.validate!(opts, [:allowed_token_ids])
+
+    allow_token_ids(logits, opts[:allowed_token_ids])
+  end
+
   defn bos_token_processor(logits, context, opts \\ []) do
     opts = keyword!(opts, [:bos_token_id])
     bos_token_id = opts[:bos_token_id]
@@ -113,6 +246,16 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     |> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
   end
 
+  deftransformp allow_token_ids(logits, allowed_token_ids) do
+    # Keep the logits of the allowed token ids and set all others to -infinity.
+    allowed_indices = Nx.tensor(allowed_token_ids)
+    allowed_logits = Nx.take(logits, allowed_indices)
+    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
+
+    indices = Nx.new_axis(allowed_indices, -1)
+    Nx.indexed_put(suppressed_logits, indices, allowed_logits)
+  end
+
   deftransformp ignore_token_id(logits, token_id) do
     Nx.put_slice(
       logits,
diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index d7a6a9a0..ad8ef26f 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,15 @@ defmodule Bumblebee.Text.GenerationConfig do
       default: [],
       doc: "a list of token ids to suppress during generation"
     ],
+    allowed_token_ids: [
+      default: [],
+      doc:
+        "a list of token ids to enforce during generation (suppressing all tokens that are not in the list)"
+    ],
+    dfa: [
+      default: nil,
+      doc: "the definition of a deterministic finite automaton (DFA) that constrains generation"
+    ],
     no_repeat_ngram_length: [
       default: nil,
       doc: "when set, n-grams of the given length can occur only once in the generated sequence"
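A minimal usage sketch for the two options added above. The repo id, the
token ids, and the three-state DFA are invented for illustration; they are
not part of the patch. Note that no transition may target state 0, since 0
marks "token not allowed" in the processor's tensor encoding:

    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "some/model"})

    generation_config =
      Bumblebee.configure(generation_config,
        # either: only ever allow a fixed set of token ids ...
        allowed_token_ids: [29, 216, 731],
        # ... or: constrain decoding with a DFA, given as an initial state
        # plus {state, token_id, next_state} triples
        dfa: %{
          mode: :stateless,
          initial_state: 0,
          state_transitions: [{0, 29, 1}, {1, 216, 2}, {2, 731, 1}]
        }
      )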
From 6884e12af730d927c752f5508b11544a05ec98c1 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Mon, 20 Oct 2025 17:55:39 +0200
Subject: [PATCH 3/3] dev script

---
 pair_programming.exs | 205 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 205 insertions(+)
 create mode 100644 pair_programming.exs

diff --git a/pair_programming.exs b/pair_programming.exs
new file mode 100644
index 00000000..1c10a1df
--- /dev/null
+++ b/pair_programming.exs
@@ -0,0 +1,205 @@
+Mix.install([
+  {:bumblebee, path: "../bumblebee_bitcrowd"},
+  {:nx, "~> 0.10.0", override: true},
+  {:exla, "~> 0.10.0"},
+  {:emlx, github: "elixir-nx/emlx"},
+  {:benchee, "~> 1.0"}
+])
+
+# backend = EMLX.Backend
+# compiler = Nx.Defn.Evaluator
+backend = EXLA.Backend
+compiler = EXLA
+
+Nx.global_default_backend(backend)
+
+repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}
+
+sequence_length = 512
+
+prompt = """
+Give me an array that contains a mix of numbers and text.
+There MUST be at least one number and one text.
+Valid examples are:
+
+["hello",89,"hola",6,4,8]
+"""
+
+# This DFA definition is "array of integers", generated by outlines-core
+# from the schema:
+#
+#   let schema = r#"{
+#     "type": "array",
+#     "items": {
+#       "type": "integer"
+#     }
+#   }"#;
+
+initial_state = 64
+
+state_transitions =
+  [
+    {96, 33, 128},
+    {96, 40, 128},
+    {96, 36, 128},
+    {96, 32, 112},
+    {96, 39, 128},
+    {96, 35, 128},
+    {96, 38, 128},
+    {96, 34, 128},
+    {96, 41, 128},
+    {96, 37, 128},
+    {144, 2, 144},
+    {176, 77, 144},
+    {224, 33, 240},
+    {224, 40, 240},
+    {224, 36, 240},
+    {224, 32, 112},
+    {224, 39, 240},
+    {224, 35, 240},
+    {224, 38, 240},
+    {224, 34, 240},
+    {224, 41, 240},
+    {224, 37, 240},
+    {128, 33, 128},
+    {128, 77, 144},
+    {128, 36, 128},
+    {128, 28, 192},
+    {128, 39, 128},
+    {128, 10790, 224},
+    {128, 34, 128},
+    {128, 37, 128},
+    {128, 40, 128},
+    {128, 32, 128},
+    {128, 216, 176},
+    {128, 35, 128},
+    {128, 38, 128},
+    {128, 41, 128},
+    {128, 6329, 144},
+    {80, 33, 128},
+    {80, 77, 144},
+    {80, 29, 96},
+    {80, 36, 128},
+    {80, 41, 128},
+    {80, 32, 112},
+    {80, 39, 128},
+    {80, 216, 176},
+    {80, 35, 128},
+    {80, 40, 128},
+    {80, 38, 128},
+    {80, 34, 128},
+    {80, 6329, 144},
+    {80, 37, 128},
+    {112, 216, 176},
+    {112, 10790, 224},
+    {112, 77, 144},
+    {112, 6329, 144},
+    {112, 28, 192},
+    {64, 9197, 96},
+    {64, 75, 160},
+    {208, 33, 240},
+    {208, 29, 224},
+    {208, 36, 240},
+    {208, 40, 240},
+    {208, 32, 112},
+    {208, 39, 240},
+    {208, 35, 240},
+    {208, 38, 240},
+    {208, 34, 240},
+    {208, 41, 240},
+    {208, 37, 240},
+    {160, 33, 128},
+    {160, 77, 144},
+    {160, 36, 128},
+    {160, 39, 128},
+    {160, 256, 176},
+    {160, 731, 96},
+    {160, 34, 128},
+    {160, 37, 128},
+    {160, 29, 96},
+    {160, 40, 128},
+    {160, 32, 112},
+    {160, 216, 80},
+    {160, 35, 128},
+    {160, 38, 128},
+    {160, 6329, 144},
+    {160, 41, 128},
+    {240, 33, 240},
+    {240, 77, 144},
+    {240, 36, 240},
+    {240, 28, 192},
+    {240, 39, 240},
+    {240, 10790, 224},
+    {240, 34, 240},
+    {240, 37, 240},
+    {240, 40, 240},
+    {240, 32, 240},
+    {240, 216, 176},
+    {240, 35, 240},
+    {240, 38, 240},
+    {240, 41, 240},
+    {240, 6329, 144},
+    {192, 33, 240},
+    {192, 29, 224},
+    {192, 36, 240},
+    {192, 40, 240},
+    {192, 32, 112},
+    {192, 39, 240},
+    {192, 216, 208},
+    {192, 35, 240},
+    {192, 731, 224},
+    {192, 38, 240},
+    {192, 34, 240},
+    {192, 41, 240},
+    {192, 37, 240}
+  ]
+
+unique_states =
+  Enum.flat_map(state_transitions, fn {state, _token_id, next_state} -> [state, next_state] end)
+  |> Enum.uniq()
+  |> Enum.sort()
+
+states_map = for {state, i} <- Enum.with_index(unique_states), into: %{}, do: {state, i}
+
+# Remap the sparse outlines-core state ids onto a dense range starting at 0,
+# so the transition tensor built by the logits processor stays small.
+compact_states =
+  Enum.map(state_transitions, fn {state, token_id, next_state} ->
+    {states_map[state], token_id, states_map[next_state]}
+  end)
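+
+# For this table, unique_states is [64, 80, 96, 112, 128, 144, 160, 176, 192,
+# 208, 224, 240], so states_map sends 64 -> 0, 80 -> 1, ..., 240 -> 11 and,
+# for example, {64, 9197, 96} becomes {0, 9197, 2}. Without the remapping,
+# dfa_processor would allocate a transition tensor with 241 rows (max state
+# id + 1) for only 12 actual states.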
+
+state_transitions = compact_states
+initial_state = states_map[initial_state]
+
+dfa = %{state_transitions: state_transitions, mode: :stateful, initial_state: initial_state}
+
+build_serving = fn backend, compiler, max_new_tokens, dfa ->
+  Nx.global_default_backend(backend)
+
+  {:ok, model_info} = Bumblebee.load_model(repo, backend: backend)
+
+  {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
+  {:ok, generation_config} = Bumblebee.load_generation_config(repo)
+
+  generation_config =
+    Bumblebee.configure(generation_config,
+      max_new_tokens: max_new_tokens,
+      # min_length: sequence_length + max_new_tokens,
+      strategy: %{type: :multinomial_sampling, top_p: 0.9},
+      # strategy: %{type: :greedy_search},
+      dfa: dfa
+    )
+
+  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+    compile: [batch_size: 1, sequence_length: sequence_length],
+    stream: false,
+    defn_options: [compiler: compiler]
+  )
+end
+
+max_new_tokens = 32
+# Toggles: set dfa = nil to generate unconstrained, or switch the mode
+# between :stateful and :stateless.
+# dfa = nil
+dfa = %{dfa | mode: :stateful}
+
+serving = build_serving.(backend, compiler, max_new_tokens, dfa)
+
+for _i <- 1..3 do
+  %{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg
+end
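+
+# Added sketch, not part of the original script: with the array-of-integers
+# DFA active, a completed generation should decode as a JSON list of integers.
+# Jason is an assumed extra dependency (add {:jason, "~> 1.4"} to Mix.install
+# above), and decoding can still fail when max_new_tokens cuts the array off
+# before the closing bracket.
+%{results: [result]} = Nx.Serving.run(serving, prompt)
+
+case Jason.decode(result.text) do
+  {:ok, integers} -> IO.inspect(integers, label: "parsed integers")
+  {:error, _} -> IO.puts("truncated or invalid JSON: #{result.text}")
+end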