diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 29162dbd..5aedd06c 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -137,6 +137,11 @@ defmodule Bumblebee do "GemmaModel" => {Bumblebee.Text.Gemma, :base}, "GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling}, "GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification}, + "Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling}, + "Gemma3TextModel" => {Bumblebee.Text.Gemma3Text, :base}, + "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling}, + "Gemma3TextForSequenceClassification" => + {Bumblebee.Text.Gemma3Text, :for_sequence_classification}, "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification}, "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification}, "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling}, @@ -252,6 +257,7 @@ defmodule Bumblebee do "camembert" => :camembert, "clip" => :clip, "gemma" => :gemma, + "gemma3_text" => :gemma, "gpt_neox" => :gpt_neo_x, "gpt2" => :gpt2, "gpt_bigcode" => :gpt2, diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index b69e3115..188b0ffe 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do - a keyword list (applied to all blocks) - a function that takes the block index and returns the configuration + * `:attention_window_size` - sliding window attention configuration. Can be: + - `nil` for global attention (default) + - a `{left, right}` tuple (applied to all blocks) + - a function that takes the block index and returns `nil` or `{left, right}`. + This enables per-layer attention patterns like Gemma 3's alternating + local/global attention (5 local layers followed by 1 global layer) + * `:name` - the prefix for layer names For all other options (including required options) see `block/2`. 
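For example, a per-layer configuration in the style of Gemma 3 (configured further down in this diff) could be sketched as follows; the 6-layer period and the 4096-token window are illustrative values, not defaults of this module:

```elixir
# Sketch only: every 6th block uses global attention (nil), all other
# blocks use a sliding window of 4096 tokens on each side. Pass the
# function as the :attention_window_size option to blocks/2.
attention_window_size = fn idx ->
  if rem(idx + 1, 6) == 0 do
    nil
  else
    {4096, 4096}
  end
end
```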
@@ -36,6 +43,8 @@ defmodule Bumblebee.Layers.Transformer do def blocks(hidden_state, opts) do validate_required_keys!(opts, [:num_blocks, :num_attention_heads, :hidden_size, :ffn]) + # Note: :attention_window_size is NOT in block_opts_keys because it's handled + # specially (supports per-layer function) and passed explicitly to block/2 block_opts_keys = [ :num_attention_heads, :num_key_value_heads, @@ -52,7 +61,6 @@ defmodule Bumblebee.Layers.Transformer do :output_use_bias, :layer_norm, :block_type, - :attention_window_size, :attention_scale, :query_norm, :key_norm @@ -66,6 +74,7 @@ defmodule Bumblebee.Layers.Transformer do :name, :num_blocks, :rotary_embedding, + :attention_window_size, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: nil, @@ -87,6 +96,7 @@ defmodule Bumblebee.Layers.Transformer do cross_attention_head_mask = opts[:cross_attention_head_mask] cache = opts[:cache] rotary_embedding = opts[:rotary_embedding] + attention_window_size = opts[:attention_window_size] block_opts = Keyword.take(opts, block_opts_keys) @@ -123,6 +133,15 @@ defmodule Bumblebee.Layers.Transformer do config when is_list(config) -> config end + # Support per-layer attention window size for models like Gemma 3 + # that alternate between local (sliding window) and global attention + block_attention_window_size = + case attention_window_size do + nil -> nil + fun when is_function(fun, 1) -> fun.(idx) + size -> size + end + {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} = block( state.hidden_state, @@ -136,6 +155,7 @@ defmodule Bumblebee.Layers.Transformer do block_cache: block_cache, offset: offset, rotary_embedding: block_rotary_embedding, + attention_window_size: block_attention_window_size, name: join(name, idx) ] ++ block_opts ) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex new file mode 100644 index 00000000..3322ab5f --- /dev/null +++ b/lib/bumblebee/text/gemma3_text.ex @@ -0,0 +1,646 @@ +defmodule Bumblebee.Text.Gemma3Text do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 262_208, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 131_072, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + hidden_size: [ + default: 2304, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 9216, + doc: "the dimensionality of intermediate layers" + ], + attention_head_size: [ + default: 256, + doc: "the size of the key, value, and query projection per attention head" + ], + attention_scale_base: [ + default: 256, + doc: """ + base value for computing attention scale. The attention scale is computed as + `attention_scale_base ** -0.5`. 
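With the default value of 256 this evaluates to `256 ** -0.5 = 1 / 16 = 0.0625`.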
+ """ + ], + num_blocks: [ + default: 26, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 8, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: 4, + doc: "the number of key value heads for each attention layer in the model" + ], + activation: [ + default: :gelu_approx_tanh, + doc: "the activation function" + ], + rotary_embedding_base: [ + default: 1_000_000, + doc: "base for computing rotary embedding frequency for global attention layers" + ], + rotary_embedding_base_local: [ + default: 10_000, + doc: "base for computing rotary embedding frequency for local (sliding) attention layers" + ], + rotary_embedding_scaling_strategy: [ + default: nil, + doc: """ + scaling configuration for rotary embedding. Currently the supported values are: + + * `%{type: :linear, factor: number()}` + + * `%{type: :dynamic, factor: number()}` + + For more details see https://www.reddit.com/r/LocalLlama/comments/14mrgpr/dynamically_scaled_rope_further_increases + """ + ], + use_attention_bias: [ + default: false, + doc: + "whether or not to use bias in the query, key, value, and output projections in attention layers" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + sliding_window: [ + default: 4096, + doc: "the sliding window size for local attention layers" + ], + layer_types: [ + default: nil, + doc: """ + a list of layer types for each layer, where each element is either `:sliding_attention` + (local attention with sliding window) or `:full_attention` (global attention) + """ + ], + tie_word_embeddings: [ + default: true, + doc: "whether to tie input and output embedding weights" + ] + ] ++ + Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + Gemma 3 model family. + + Gemma 3 is an updated version of the Gemma architecture with several key improvements: + + * Alternating local/global attention (5:1 ratio by default) for better efficiency + * Larger vocabulary (262K tokens) + * Extended context length (up to 128K tokens) + + This module also supports FunctionGemma, which is built on Gemma 3 and optimized + for function calling tasks. + + ## Architectures + + * `:base` - plain Gemma 3 without any head on top + + * `:for_causal_language_modeling` - Gemma 3 with a language modeling + head. The head returns logits for each token in the original + sequence + + * `:for_sequence_classification` - Gemma 3 with a sequence + classification head. The head returns logits corresponding to + possible classes + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"attention_head_mask"` - `{encoder_num_blocks, encoder_num_attention_heads}` + + Mask to nullify selected heads of the self-attention blocks in + the encoder. 
+ + * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}` + + Embedded representation of `"input_ids"`, which can be specified + for more control over how `"input_ids"` are embedded than the + model's internal embedding lookup. If `"input_embeddings"` are present, + then `"input_ids"` will be ignored. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). With cache, certain hidden states are + taken from the cache, rather than recomputed on every decoding + pass. The cache should be treated as opaque and initialized with + `Bumblebee.Text.Generation.init_cache/4`. + + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", 
optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + shift: 1.0, + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + # Note: Gemma 3 still normalizes embeddings by sqrt(hidden_size), same as Gemma v1 + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + |> Axon.nx(fn x -> + normalization_factor = + spec.hidden_size + |> Nx.tensor(type: Nx.type(x)) + |> Nx.sqrt() + + Nx.multiply(x, normalization_factor) + end) + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + # QK-norm functions for Gemma 3 (uses shift: 1.0 for (1+weight) formula) + query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + + # Per-layer attention window size based on layer_types + # :sliding_attention uses local (sliding window) attention + # :full_attention uses global attention (nil window size) + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + + attention_window_size = fn idx -> + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> nil + :sliding_attention -> {spec.sliding_window, spec.sliding_window} + end + end + + # Per-layer rotary embedding base: local layers use rotary_embedding_base_local, + # global layers use rotary_embedding_base + rotary_embedding = fn idx -> + base = + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.rotary_embedding_base + :sliding_attention -> spec.rotary_embedding_base_local + end + + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ] + end + + attention_scale = :math.pow(spec.attention_scale_base, -0.5) + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + attention_scale: attention_scale, + kernel_initializer: kernel_initializer(spec), + layer_norm: + &Layers.rms_norm(&1, + shift: 1.0, + name: &2, + epsilon: spec.layer_norm_epsilon, 
+ upcast: :all + ), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: &gemma3_block_impl(&1, &2, &3, spec), + causal: true, + rotary_embedding: rotary_embedding, + attention_window_size: attention_window_size, + query_norm: query_norm, + key_norm: key_norm, + query_use_bias: spec.use_attention_bias, + key_use_bias: spec.use_attention_bias, + value_use_bias: spec.use_attention_bias, + output_use_bias: spec.use_attention_bias, + name: join(name, "blocks") + ) + end + + # Custom block implementation for Gemma 3's unique normalization structure: + # - Post-attention norm BEFORE residual add + # - Pre/post FFN norms + defp gemma3_block_impl(hidden_state, steps, name, spec) do + # Pre-attention norm + attention (using provided steps) + shortcut = hidden_state + + {hidden_state, attention_info} = + hidden_state + |> steps.self_attention_norm.() + |> steps.self_attention.() + + # Post-attention norm BEFORE residual (Gemma 3 specific) + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_attention_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + # FFN with pre/post norms (Gemma 3 specific) + shortcut = hidden_state + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "pre_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = steps.ffn.(hidden_state) + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + # Handle cross-attention (required by block interface but not used by Gemma 3) + {_hidden_state, cross_attention_info} = + steps.cross_attention_maybe.(hidden_state, fn _ -> + raise "cross attention not supported" + end) + + {hidden_state, attention_info, cross_attention_info} + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: Tie lm-head to word embedding as a spec option + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + # Generate layer_types from sliding_window_pattern (default 6) + # Pattern: every Nth layer uses full attention, others use sliding attention + defp generate_layer_types(num_blocks) do + sliding_window_pattern = 6 + + Enum.map(0..(num_blocks - 1), fn i -> + if rem(i + 1, sliding_window_pattern) == 0 do + :full_attention + else + :sliding_attention + end + end) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + scaling_strategy_converter = fn name, value -> + case value do + %{"type" => "linear", "factor" => factor} when is_number(factor) -> + {:ok, %{type: 
:linear, factor: factor}} + + %{"type" => "dynamic", "factor" => factor} when is_number(factor) -> + {:ok, %{type: :dynamic, factor: factor}} + + _other -> + {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"} + end + end + + # Support sliding_window_pattern for backward compatibility + # see https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/gemma3/configuration_gemma3.py#L188-L195 + data = + Map.put_new_lazy(data, "layer_types", fn -> + pattern = data["sliding_window_pattern"] || 6 + num_blocks = data["num_hidden_layers"] || 26 + + Enum.map(0..(num_blocks - 1), fn i -> + if rem(i + 1, pattern) == 0 do + "full_attention" + else + "sliding_attention" + end + end) + end) + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + attention_head_size: {"head_dim", number()}, + attention_scale_base: {"query_pre_attn_scalar", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_activation", activation()}, + use_attention_bias: {"attention_bias", boolean()}, + rotary_embedding_base: {"rope_theta", number()}, + rotary_embedding_base_local: {"rope_local_base_freq", number()}, + rotary_embedding_scaling_strategy: + {"rope_scaling", optional(scaling_strategy_converter)}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + sliding_window: {"sliding_window", optional(number())}, + layer_types: + {"layer_types", + list( + mapping(%{ + "sliding_attention" => :sliding_attention, + "full_attention" => :full_attention + }) + )}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + # Gemma 3 specific params mapping with QK-norm and extra FFN layer norms + %{ + "embedder.token_embedding" => "model.embed_tokens", + # Attention projections + "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + # QK-norm (Gemma 3 specific) - uses query_norm/key_norm from shared infrastructure + "decoder.blocks.{n}.self_attention.query_norm" => "model.layers.{n}.self_attn.q_norm", + "decoder.blocks.{n}.self_attention.key_norm" => "model.layers.{n}.self_attn.k_norm", + # Layer norms + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.post_attention_norm" => "model.layers.{n}.post_attention_layernorm", + # FFN layer norms (Gemma 3 specific) + "decoder.blocks.{n}.pre_ffn_norm" => "model.layers.{n}.pre_feedforward_layernorm", + "decoder.blocks.{n}.post_ffn_norm" => "model.layers.{n}.post_feedforward_layernorm", + # FFN projections + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + # Output + "output_norm" => "model.norm", + 
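# Tied to the token embedding when tie_word_embeddings is set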
"language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), + "sequence_classification_head.output" => "score" + } + end + end +end diff --git a/mix.exs b/mix.exs index 44210c13..a768a59f 100644 --- a/mix.exs +++ b/mix.exs @@ -64,6 +64,7 @@ defmodule Bumblebee.MixProject do "notebooks/llms.livemd", "notebooks/llms_rag.livemd", "notebooks/qwen3.livemd", + "notebooks/function_calling.livemd", "notebooks/fine_tuning.livemd", "examples/phoenix/README.md" ], diff --git a/notebooks/function_calling.livemd b/notebooks/function_calling.livemd new file mode 100644 index 00000000..d9752bc9 --- /dev/null +++ b/notebooks/function_calling.livemd @@ -0,0 +1,523 @@ +# Function calling with FunctionGemma + +```elixir +Mix.install([ + {:bumblebee, "~> 0.6.0"}, + {:nx, "~> 0.9.0"}, + {:exla, "~> 0.9.0"}, + {:kino, "~> 0.14.0"} +]) + +Nx.global_default_backend(EXLA.Backend) +``` + +## Why FunctionGemma? + +[FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) is a compact 270M parameter model from Google, specifically designed for function calling tasks. + +## Loading the Model + +FunctionGemma requires accepting Google's license on HuggingFace. Visit [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it) to request access, then create a [HuggingFace auth token](https://huggingface.co/settings/tokens) and add it as a `HF_TOKEN` Livebook secret. + +```elixir +hf_token = System.fetch_env!("LB_HF_TOKEN") +repo = {:hf, "google/functiongemma-270m-it", auth_token: hf_token} + +{:ok, model_info} = Bumblebee.load_model(repo) +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +:ok +``` + +## Creating the Serving + +```elixir +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 512], + defn_options: [compiler: EXLA] + ) + +Kino.start_child({Nx.Serving, name: FunctionGemma, serving: serving}) +``` + +## Function Schema Builder + +FunctionGemma uses a specific prompt format. Here's a complete module to build function declarations: + +```elixir +defmodule FunctionGemma.Schema do + @moduledoc """ + Builds FunctionGemma-compatible function declarations. + + ## Example + + FunctionGemma.Schema.declare("get_weather", "Get current weather", [ + location: [type: :string, description: "City name", required: true], + units: [type: :string, description: "celsius or fahrenheit"] + ]) + """ + + @type param_opts :: [ + type: :string | :number | :boolean | :array, + description: String.t(), + required: boolean() + ] + + @doc """ + Declares a function with its name, description, and parameters. + + ## Parameters + + - `name` - The function name (e.g., "get_weather") + - `description` - What the function does + - `parameters` - Keyword list of `{param_name, options}` + + ## Parameter Options + + - `:type` - One of `:string`, `:number`, `:boolean`, `:array` (default: `:string`) + - `:description` - Description of the parameter + - `:required` - Whether the parameter is required (default: `false`) + """ + @spec declare(String.t(), String.t(), keyword(param_opts())) :: String.t() + def declare(name, description, parameters \\ []) do + params_schema = build_parameters_schema(parameters) + + "" <> + "declaration:#{name}{" <> + "description:#{description}," <> + "parameters:#{params_schema}" <> + "}" + end + + @doc """ + Builds a complete prompt with system message, functions, and user query. 
+ """ + @spec build_prompt(String.t(), [String.t()], String.t()) :: String.t() + def build_prompt(system_message, function_declarations, user_message) do + functions = Enum.join(function_declarations, "") + + """ + developer + #{system_message} + #{functions} + user + #{user_message} + model + """ + end + + # Private helpers + + defp build_parameters_schema(parameters) do + properties = build_properties(parameters) + required = build_required(parameters) + + "{properties:{#{properties}},required:[#{required}],type:OBJECT}" + end + + defp build_properties(parameters) do + parameters + |> Enum.map(fn {name, opts} -> + type = opts |> Keyword.get(:type, :string) |> type_to_string() + desc = Keyword.get(opts, :description) + + prop = + if desc do + "#{name}:{description:#{desc},type:#{type}}" + else + "#{name}:{type:#{type}}" + end + + prop + end) + |> Enum.join(",") + end + + defp build_required(parameters) do + parameters + |> Enum.filter(fn {_, opts} -> Keyword.get(opts, :required, false) end) + |> Enum.map(fn {name, _} -> "#{name}" end) + |> Enum.join(",") + end + + defp type_to_string(:string), do: "STRING" + defp type_to_string(:number), do: "NUMBER" + defp type_to_string(:boolean), do: "BOOLEAN" + defp type_to_string(:array), do: "ARRAY" + defp type_to_string(other), do: String.upcase(to_string(other)) +end +``` + +## Function Call Parser + +Parse the model's function call output into structured data: + +```elixir +defmodule FunctionGemma.Parser do + @moduledoc """ + Parses FunctionGemma function call responses. + """ + + @type function_call :: %{ + function: String.t(), + arguments: map() + } + + @doc """ + Parses a FunctionGemma response into a function call struct. + + ## Examples + + iex> parse("call:get_weather{location:Paris}") + {:ok, %{function: "get_weather", arguments: %{"location" => "Paris"}}} + + iex> parse("I don't know") + {:error, :no_function_call} + """ + @spec parse(String.t()) :: {:ok, function_call()} | {:error, atom()} + def parse(response) do + pattern = ~r/call:(\w+)\{(.*?)\}/ + + case Regex.run(pattern, response) do + [_, function_name, args_str] -> + arguments = parse_arguments(args_str) + {:ok, %{function: function_name, arguments: arguments}} + + nil -> + {:error, :no_function_call} + end + end + + @doc """ + Same as `parse/1` but raises on error. + """ + @spec parse!(String.t()) :: function_call() + def parse!(response) do + case parse(response) do + {:ok, result} -> result + {:error, reason} -> raise "Failed to parse function call: #{reason}" + end + end + + # Parse key:value pairs + defp parse_arguments(""), do: %{} + + defp parse_arguments(args_str) do + ~r/(\w+):([^<]*)/ + |> Regex.scan(args_str) + |> Enum.map(fn [_, key, value] -> {key, value} end) + |> Map.new() + end +end +``` + +## Mock Functions (Smart Home Example) + +Let's create actual mock functions that simulate a smart home system: + +```elixir +defmodule SmartHome do + @moduledoc """ + Mock smart home functions that FunctionGemma can call. + """ + + # Simulated device states + use Agent + + def start_link do + Agent.start_link( + fn -> + %{ + lights: %{ + "living room" => false, + "bedroom" => false, + "kitchen" => false + }, + thermostat: 20, + weather_cache: %{} + } + end, + name: __MODULE__ + ) + end + + @doc """ + Controls a light in a specific room. 
+ + ## Parameters + - room: The room name (living room, bedroom, kitchen) + - action: "on" or "off" + """ + def control_light(%{"room" => room, "action" => action}) do + room = String.downcase(room) + state = action == "on" + + Agent.update(__MODULE__, fn data -> + put_in(data, [:lights, room], state) + end) + + current = Agent.get(__MODULE__, & &1.lights) + + %{ + success: true, + message: "Turned #{action} the #{room} light", + current_states: current + } + end + + def control_light(_), do: %{success: false, message: "Missing room or action parameter"} + + @doc """ + Gets the current weather for a location (mocked with random data). + + ## Parameters + - location: The city name + """ + def get_weather(%{"location" => location}) do + # Simulate weather data + conditions = ["sunny", "cloudy", "rainy", "partly cloudy", "windy"] + temp = Enum.random(15..30) + humidity = Enum.random(40..80) + condition = Enum.random(conditions) + + %{ + success: true, + location: location, + temperature: temp, + humidity: humidity, + condition: condition, + message: "Weather in #{location}: #{temp}C, #{condition}, #{humidity}% humidity" + } + end + + def get_weather(_), do: %{success: false, message: "Missing location parameter"} + + @doc """ + Sets the thermostat temperature. + + ## Parameters + - temperature: Temperature in Celsius (number as string) + """ + def set_thermostat(%{"temperature" => temp_str}) do + temp = + case Integer.parse(temp_str) do + {t, _} -> t + :error -> 20 + end + + Agent.update(__MODULE__, fn data -> + Map.put(data, :thermostat, temp) + end) + + %{ + success: true, + message: "Thermostat set to #{temp}C", + temperature: temp + } + end + + def set_thermostat(_), do: %{success: false, message: "Missing temperature parameter"} + + @doc """ + Returns current state of all devices. + """ + def get_status do + Agent.get(__MODULE__, & &1) + end +end + +# Start the mock smart home +SmartHome.start_link() +IO.puts("Smart Home system initialized!") +IO.inspect(SmartHome.get_status(), label: "Initial state") +``` + +## Function Executor + +Now let's create an executor that connects FunctionGemma to our mock functions: + +```elixir +defmodule FunctionGemma.Executor do + @moduledoc """ + Executes function calls from FunctionGemma using registered handlers. + """ + + @doc """ + Executes a parsed function call against registered handlers. + """ + def execute(%{function: function, arguments: args}, handlers) do + case Map.get(handlers, function) do + nil -> + {:error, "Unknown function: #{function}"} + + handler when is_function(handler, 1) -> + result = handler.(args) + {:ok, result} + end + end + + @doc """ + Complete pipeline: send prompt to model, parse response, execute function. 
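Returns `{:ok, function_call, result}` on success, or `{:error, reason}` if the model response contains no function call or the function has no registered handler.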
+ """ + def run(serving_name, prompt, handlers) do + # Get model response + %{results: [%{text: response}]} = Nx.Serving.batched_run(serving_name, prompt) + + IO.puts("Model response: #{response}") + + # Parse function call + case FunctionGemma.Parser.parse(response) do + {:ok, function_call} -> + IO.puts("Parsed: #{function_call.function}(#{inspect(function_call.arguments)})") + + # Execute function + case execute(function_call, handlers) do + {:ok, result} -> + {:ok, function_call, result} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end +end +``` + +## Putting It All Together + +Let's define our function schema and handlers, then run the complete pipeline: + +```elixir +# Define function declarations for the model +function_declarations = [ + FunctionGemma.Schema.declare( + "control_light", + "Turn a light on or off in a specific room", + room: [type: :string, description: "The room name (living room, bedroom, kitchen)", required: true], + action: [type: :string, description: "on or off", required: true] + ), + FunctionGemma.Schema.declare( + "get_weather", + "Get the current weather for a location", + location: [type: :string, description: "The city name", required: true] + ), + FunctionGemma.Schema.declare( + "set_thermostat", + "Set the thermostat temperature", + temperature: [type: :number, description: "Temperature in Celsius", required: true] + ) +] + +# Map function names to their implementations +function_handlers = %{ + "control_light" => &SmartHome.control_light/1, + "get_weather" => &SmartHome.get_weather/1, + "set_thermostat" => &SmartHome.set_thermostat/1 +} + +IO.puts("Registered #{length(function_declarations)} functions") +:ok +``` + +## Interactive Demo + +Try sending commands to the smart home assistant: + +```elixir +user_input = Kino.Input.textarea("Command", + default: "Turn on the lights in the living room" +) +``` + +```elixir +user_message = Kino.Input.read(user_input) + +prompt = + FunctionGemma.Schema.build_prompt( + "You are a smart home assistant that controls devices and provides information.", + function_declarations, + user_message + ) + +IO.puts("=== Sending to FunctionGemma ===") +IO.puts("User: #{user_message}\n") + +case FunctionGemma.Executor.run(FunctionGemma, prompt, function_handlers) do + {:ok, function_call, result} -> + IO.puts("\n=== Function Executed ===") + IO.puts("Function: #{function_call.function}") + IO.puts("Arguments: #{inspect(function_call.arguments)}") + IO.puts("\n=== Result ===") + IO.inspect(result, pretty: true) + + {:error, reason} -> + IO.puts("Error: #{inspect(reason)}") +end +``` + +## Batch Demo - Multiple Commands + +Watch the smart home respond to multiple commands: + +```elixir +commands = [ + "What's the weather in Tokyo?", + "Turn on the bedroom lights", + "Set the temperature to 22 degrees", + "Turn off the kitchen light" +] + +Kino.Shorts.data_table( + for command <- commands do + prompt = + FunctionGemma.Schema.build_prompt( + "You are a smart home assistant.", + function_declarations, + command + ) + + result = + case FunctionGemma.Executor.run(FunctionGemma, prompt, function_handlers) do + {:ok, fc, res} -> + %{ + command: command, + function: fc.function, + args: inspect(fc.arguments), + result: res.message + } + + {:error, reason} -> + %{command: command, function: "ERROR", args: "", result: inspect(reason)} + end + + IO.puts("---") + result + end +) +``` + +## Check Final Smart Home State + +```elixir +IO.puts("=== Final Smart Home State ===") 
+SmartHome.get_status() |> IO.inspect(pretty: true) +``` + +## Next Steps + +- **Fine-tune** on your specific function schemas for better accuracy +- **Add function responses** for multi-turn conversations +- **Integrate** with your actual APIs and services +- **Deploy** as a Phoenix LiveView application + +## Fine-tuning FunctionGemma + +Want to fine-tune FunctionGemma on your own function schemas? Check out these resources: + +- [FunctionGemma Fine-tuning Notebook (Colab)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/FunctionGemma_(270M)-Mobile-Actions.ipynb) - Step-by-step guide using Unsloth for efficient fine-tuning on Google Colab T4 +- [Google's FunctionGemma documentation](https://huggingface.co/google/functiongemma-270m-it) - Official model card and usage instructions diff --git a/test/bumblebee/text/gemma3_text_test.exs b/test/bumblebee/text/gemma3_text_test.exs new file mode 100644 index 00000000..6af3ce6a --- /dev/null +++ b/test/bumblebee/text/gemma3_text_test.exs @@ -0,0 +1,79 @@ +defmodule Bumblebee.Text.Gemma3TextTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Gemma3TextModel"}) + + assert %Bumblebee.Text.Gemma3Text{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-0.2461, 1.2074, 0.7663], [0.0675, 0.3987, 1.6659], [-0.3021, 0.8062, 1.0309]] + ]), + atol: 1.0e-4 + ) + end + + test ":for_sequence_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "bumblebee-testing/tiny-random-Gemma3TextForSequenceClassification"} + ) + + assert %Bumblebee.Text.Gemma3Text{architecture: :for_sequence_classification} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([[-0.0145, 0.1376]]), + atol: 1.0e-4 + ) + end + + test ":for_causal_language_modeling" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Gemma3ForCausalLM"}) + + assert %Bumblebee.Text.Gemma3Text{architecture: :for_causal_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 1024} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [[-0.0488, 0.0432, -0.0531], [-0.1553, -0.0812, 0.1153], [-0.0272, 0.1216, 0.0129]] + ]), + atol: 1.0e-4 + ) + end +end