diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 2b9547d1..4988bd07 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -638,14 +638,17 @@ defmodule Bumblebee.Text.Generation do end defnp batch_process_logits(logits_processor_fun, logits, state) do - logits - |> Nx.vectorize(:batch) - |> logits_processor_fun.(%{ - sequence: Nx.vectorize(state.sequences, :batch), - length: state.length, - input_length: state.input_length - }) - |> Nx.devectorize(keep_names: false) + case logits + |> Nx.vectorize(:batch) + |> logits_processor_fun.(%{ + sequence: Nx.vectorize(state.sequences, :batch), + length: state.length, + input_length: state.input_length, + last_state: state[:last] + }) do + {logits, new_state} -> {Nx.devectorize(logits, keep_names: false), new_state} + logits -> Nx.devectorize(logits, keep_names: false) + end end # Contrastive search @@ -894,7 +897,15 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + + {logits, new_state} = + case batch_process_logits(logits_processor_fun, logits, state) do + {logits, new_state} -> {logits, new_state} + logits -> {logits, state} + end + + state = Map.merge(state, new_state) + scores = Axon.Activations.softmax(logits) token_id = batched_choice(key, scores) diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex index 3eb7ee10..fab6864a 100644 --- a/lib/bumblebee/text/generation/logits_processing.ex +++ b/lib/bumblebee/text/generation/logits_processing.ex @@ -11,6 +11,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do Enum.dedup_by(dfa.state_transitions, fn {state, _token_id, _next_state} -> state end) |> length() + state_transition_tensor = Nx.broadcast(0, {num_states, Nx.size(logits)}) state_transitions_tensor = @@ -26,19 +27,20 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1) - current_state = - find_current_state( - initial_state, - state_transitions_tensor, - context.sequence, - context.input_length, - context.length - ) + last_state = + if last_state = context[:last_state] do + Nx.tensor([last_state]) |> Nx.vectorize(batch: 1) + else + initial_state + end + + last_token_id = context.sequence[context.length] + current_state = state_transitions_tensor[[last_state, last_token_id]] |> Nx.squeeze() suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits)) logits = Nx.select(state_transitions_tensor[current_state], logits, suppressed_logits) - logits + {logits, current_state} end defn find_current_state( diff --git a/pair_programming.exs b/pair_programming.exs index 7a264eba..bc514686 100644 --- a/pair_programming.exs +++ b/pair_programming.exs @@ -41,7 +41,7 @@ states = [ :done ] -state_to_num = fn state -> Enum.find_index(states, & &1 == state) end +state_to_num = fn state -> Enum.find_index(states, &(&1 == state)) end # ------------------------------------- above chars ------------------------------ # # ------------------------------------- below tokens ------------------------------ # @@ -107,7 +107,7 @@ state_transitions = end end) -dfa = %{ state_transitions: state_transitions, } +dfa = %{state_transitions: state_transitions} generation_config = Bumblebee.configure(generation_config, @@ -123,46 +123,36 @@ serving = defn_options: [compiler: Nx.Defn.Evaluator] ) -%{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg +%{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg # IO.puts result.text run_benchmarks = fn -> -serving_fn = fn max_new_tokens, dfa -> - generation_config = - Bumblebee.configure(generation_config, - max_new_tokens: max_new_tokens, - strategy: %{type: :multinomial_sampling, top_p: 0.6}, - dfa: dfa - ) - - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: sequence_length], - stream: false, - defn_options: [compiler: Nx.Defn.Evaluator] - ) - - end - -serving_dfa_8 = serving_fn.(8, dfa) -serving_dfa_16 = serving_fn.(16, dfa) -serving_dfa_8_no_skip = serving_fn.(8, Map.delete(dfa, :ambiguous_token_ids)) -serving_dfa_16_no_skip = serving_fn.(16, Map.delete(dfa, :ambiguous_token_ids)) -serving_no_dfa_8 = serving_fn.(8, nil) -serving_no_dfa_16 = serving_fn.(16, nil) - -Benchee.run( - %{ - "max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8, prompt) end, - "max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16, prompt) end, - "no skip: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8_no_skip, prompt) end, - "no skip: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16_no_skip, prompt) end, - "no dfa: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_no_dfa_8, prompt) end, - "no dfa: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_no_dfa_16, prompt) end, - }, - time: 30, - memory_time: 2 -) + serving_fn = fn max_new_tokens, dfa -> + generation_config = + Bumblebee.configure(generation_config, + max_new_tokens: max_new_tokens, + strategy: %{type: :multinomial_sampling, top_p: 0.6}, + dfa: dfa + ) + + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: sequence_length], + stream: false, + defn_options: [compiler: Nx.Defn.Evaluator] + ) + end + + # Benchee.run( + # %{ + # "max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8, prompt) end, + # "max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16, prompt) end, + # "no skip: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8_no_skip, prompt) end, + # "no skip: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16_no_skip, prompt) end, + # "no dfa: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_no_dfa_8, prompt) end, + # "no dfa: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_no_dfa_16, prompt) end, + # }, + # time: 30, + # memory_time: 2 + # ) end - -