3 changes: 3 additions & 0 deletions lib/bumblebee/text/generation.ex
@@ -361,6 +361,9 @@ defmodule Bumblebee.Text.Generation do
        end,
        if config.temperature && config.temperature != 1.0 do
          &temperature_processor(&1, &2, temperature: config.temperature)
        end,
        if config.grammar do
          &grammar_constrained_processor(&1, &2, grammar: config.grammar)
        end
      ] ++
        if config.strategy.type == :multinomial_sampling do
94 changes: 94 additions & 0 deletions lib/bumblebee/text/generation/grammar_constraint.ex
@@ -0,0 +1,94 @@
defmodule Bumblebee.Text.Generation.GrammarConstraint do
  @moduledoc false

  import Nx.Defn

  alias Bumblebee.Text.Generation.TokenTrie
  alias Bumblebee.Text.Generation.Stack
  alias EBNF.ParseState

  alias __MODULE__

  # Models a grammar constraint used to restrict token generation to
  # sequences that remain valid under a given grammar.
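  #
  # The grammar encoding is assumed to follow a llama.cpp GBNF-style
  # binary layout: for each rule, a rule id, then one or more
  # alternates (each stored as an element count followed by its
  # elements), terminated by a 0, with 0xFFFF closing the encoding.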

  defstruct [
    :token_trie,
    :grammar_encoding,
    :tokenizer,
    :start_rule_id,
    :start_rule_position,
    :rule_positions
  ]

  def create(grammar, root, tokenizer) do
    %ParseState{symbol_ids: symbols, grammar_encoding: encoding} = EBNF.encode(grammar)
    trie = TokenTrie.create(tokenizer)
    start_rule_id = Map.fetch!(symbols, root)
    rule_positions = get_rule_positions(encoding)

    %GrammarConstraint{
      token_trie: trie,
      grammar_encoding: encoding,
      tokenizer: tokenizer,
      start_rule_id: start_rule_id,
      start_rule_position: Map.fetch!(rule_positions, start_rule_id),
      rule_positions: rule_positions
    }
  end
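
  # Illustrative usage (assuming a GBNF-style grammar string):
  #
  #     constraint = GrammarConstraint.create(grammar, "root", tokenizer)
  #     stacks = GrammarConstraint.init_stacks(constraint)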

  def init_stacks(constraint) do
    # The stack will never exceed the grammar encoding size. We push
    # start_rule_position + 2 to skip over the rule id and the element
    # count, landing on the first element of the start rule.
    Stack.new(length(constraint.grammar_encoding))
    |> Stack.push(constraint.start_rule_position + 2)
    |> advance_stack()
  end

  defn advance_stack(stack) do
    if Nx.equal(Stack.length(stack), 0) do
      stack
    else
      top = Stack.peek(stack)

      if Nx.equal(top, 2) do
        stack
      else
        # TODO: expand the grammar element at the top of the stack
        # (not yet implemented in this draft)
        stack
      end
    end
  end

  defp get_rule_positions(grammar_encoding) do
    recur_get_rule_positions(grammar_encoding, 0, %{})
  end

  defp recur_get_rule_positions([0xFFFF], _pos, rule_positions), do: rule_positions

  defp recur_get_rule_positions([rule_id | rest], pos, rule_positions) do
    rule_positions = Map.put(rule_positions, rule_id, pos)

    case find_next_rule(rest, pos + 1) do
      {[_ | leftover], pos} ->
        recur_get_rule_positions(leftover, pos + 1, rule_positions)

      {[], _} ->
        rule_positions
    end
  end
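
  # Example (illustrative): for the encoding [1, 2, 10, 11, 0, 0xFFFF],
  # i.e. a single rule with id 1 and one two-element alternate,
  # get_rule_positions/1 returns %{1 => 0}.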

  # Returns the remainder of the encoding starting at the 0 that
  # terminates the current rule, together with the position of that 0.
  defp find_next_rule([0 | _] = leftover, pos) do
    {leftover, pos}
  end

  defp find_next_rule([rule_size | _] = leftover, pos) do
    leftover = Enum.drop(leftover, rule_size + 1)
    pos = pos + rule_size + 1

    case leftover do
      [0 | _] ->
        {leftover, pos}

      leftover ->
        find_next_rule(leftover, pos)
    end
  end
end
12 changes: 12 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do

  import Nx.Defn

  alias Bumblebee.Text.Generation.GrammarConstraint

  deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
    opts = Keyword.validate!(opts, [:suppressed_token_ids])

@@ -255,4 +257,14 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
      {idx, _token_id} -> idx + 1
    end
  end

  deftransform grammar_constrained_processor(logits, _input_ids, opts \\ []) do
    opts = Keyword.validate!(opts, [:grammar, :tokenizer])

    grammar = opts[:grammar]
    tokenizer = opts[:tokenizer]

    constraint = GrammarConstraint.create(grammar, "root", tokenizer)
    _batch_stacks = GrammarConstraint.init_stacks(constraint)

    # TODO: advance the stacks with the generated ids and mask logits
    # for tokens that are invalid under the grammar; this draft returns
    # the logits unchanged for now.
    logits
  end
end
64 changes: 64 additions & 0 deletions lib/bumblebee/text/generation/stack.ex
@@ -0,0 +1,64 @@
defmodule Bumblebee.Text.Generation.Stack do
  @moduledoc false

  # A stack-like data structure represented as an Nx container to make
  # constrained sampling possible. The HF implementation uses a dynamic
  # stack, but we need all shapes up front and cannot resize tensors at
  # runtime, so we use a fixed-size buffer together with a pointer to
  # the top of the stack.
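  #
  # A sketch of the intended usage (illustrative):
  #
  #     stack = Stack.new(10)
  #     stack = Stack.push(stack, 5)
  #     top = Stack.peek(stack)
  #     {value, stack} = Stack.pop(stack)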

  alias __MODULE__

  @derive {Nx.Container, containers: [:data, :pointer]}
  defstruct [:data, :pointer]

  import Nx.Defn

  @empty_value -1

  @doc """
  Initializes a new stack with the given fixed size.
  """
  def new(size, opts \\ []) do
    opts = Keyword.validate!(opts, type: :s64)

    %Stack{
      data: Nx.broadcast(Nx.tensor(@empty_value, type: opts[:type]), {size}),
      pointer: Nx.tensor(0)
    }
  end

  @doc """
  Pushes a value onto the top of the stack.
  """
  deftransform push(%Stack{data: data, pointer: pointer} = stack, value) do
    unless Nx.rank(value) == 0, do: raise(ArgumentError, "can only push scalar values to the stack")

    %{
      stack
      | data: Nx.put_slice(data, [pointer], Nx.reshape(value, {1})),
        pointer: Nx.add(pointer, 1)
    }
  end

  @doc """
  Pops the value from the top of the stack.
  """
  defn pop(%Stack{data: data, pointer: pointer} = stack) do
    # the pointer points one past the current top of the stack
    value = data[pointer - 1]
    {value, %{stack | pointer: pointer - 1}}
  end

  @doc """
  Peeks at the value on top of the stack.
  """
  defn peek(%Stack{data: data, pointer: pointer}) do
    data[pointer - 1]
  end

  @doc """
  Returns the length of the stack.
  """
  defn length(%Stack{pointer: pointer}) do
    pointer
  end
end
61 changes: 61 additions & 0 deletions lib/bumblebee/text/generation/token_trie.ex
@@ -0,0 +1,61 @@
defmodule Bumblebee.Text.Generation.TokenTrie do
  @moduledoc false

  # Internal data structure used in constrained sampling. It arranges
  # the tokenizer vocabulary into a byte-level trie so that candidate
  # tokens can be matched byte by byte.

  alias Bumblebee.Text.PreTrainedTokenizer
  alias __MODULE__

  defstruct [:tokens, :trie, :eos_token_id]

  @leaf -1
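  # (nodes map a byte to child nodes; the @leaf key marks nodes where a
  # complete token ends and stores that token's id)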

  @doc """
  Returns the token encoded by the given id.
  """
  def id_to_token(%TokenTrie{tokens: tokens}, id) do
    Map.fetch!(tokens, id)
  end

  @doc """
  Returns the number of tokens in the trie.
  """
  def n_tokens(%TokenTrie{tokens: tokens}) do
    map_size(tokens)
  end

  @doc """
  Creates a trie from the vocabulary in the given tokenizer.
  """
  def create(%PreTrainedTokenizer{native_tokenizer: tokenizer, special_tokens: %{eos: eos_token}}) do
    vocab = Tokenizers.Tokenizer.get_vocab(tokenizer)
    eos_token_id = Map.fetch!(vocab, eos_token)

    tokens =
      Map.new(vocab, fn {token, id} ->
        # TODO: Special cases for GPT2 and Llama
        {id, String.to_charlist(token)}
      end)

    trie =
      Enum.reduce(tokens, %{}, fn {token_id, token_bytes}, acc ->
        insert_into_trie(acc, token_bytes, token_id)
      end)

    %TokenTrie{tokens: tokens, trie: trie, eos_token_id: eos_token_id}
  end
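
  # For example (illustrative), a vocabulary with tokens "ab" (id 0)
  # and "ac" (id 1) produces the trie:
  #
  #     %{?a => %{?b => %{-1 => 0}, ?c => %{-1 => 1}}}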

  ## Helpers

  defp insert_into_trie(trie, token_bytes, token_id) do
    do_insert_into_trie(trie, token_bytes, token_id)
  end

  defp do_insert_into_trie(trie, [], token_id), do: Map.put(trie, @leaf, token_id)

  defp do_insert_into_trie(trie, [byte | rest_bytes], token_id) do
    current = Map.get(trie, byte, %{})
    updated = do_insert_into_trie(current, rest_bytes, token_id)
    Map.put(trie, byte, updated)
  end
end
3 changes: 2 additions & 1 deletion mix.exs
@@ -49,7 +49,8 @@ defmodule Bumblebee.MixProject do
      {:stb_image, "~> 0.6.0", only: :test},
      {:bypass, "~> 2.1", only: :test},
      {:ex_doc, "~> 0.28", only: :dev, runtime: false},
      {:nx_signal, "~> 0.2.0"},
      {:ebnf, github: "seanmor5/ebnf"}
    ]
  end

1 change: 1 addition & 0 deletions mix.lock
@@ -9,6 +9,7 @@
"cowlib": {:hex, :cowlib, "2.11.0", "0b9ff9c346629256c42ebe1eeb769a83c6cb771a6ee5960bd110ab0b9b872063", [:make, :rebar3], [], "hexpm", "2b3e9da0b21c4565751a6d4901c20d1b4cc25cbb7fd50d91d2ab6dd287bc86a9"},
"decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"ebnf": {:git, "https://github.com/seanmor5/ebnf.git", "a69e84619881b27fa8eceff29713ff9b496814cd", []},
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
"ex_doc": {:hex, :ex_doc, "0.31.0", "06eb1dfd787445d9cab9a45088405593dd3bb7fe99e097eaa71f37ba80c7a676", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5350cafa6b7f77bdd107aa2199fe277acf29d739aba5aee7e865fc680c62a110"},
"exla": {:hex, :exla, "0.7.0", "27fac40a580f0d3816fe3bf35c50dfc2f99597d26ac7e2aca4a3c62b89bb427f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d3bfc622deb52cec95efc9d76063891afc7cd33e38eddbb01f3385c53e043c40"},