diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index a191f5bf..b92eef42 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -607,7 +607,8 @@ defmodule Bumblebee do :params_filename, :log_params_diff, :backend, - :type + :type, + :preserve_source_types ]) with {:ok, repo_files} <- get_repo_files(repository), @@ -654,7 +655,7 @@ defmodule Bumblebee do [ params_mapping: params_mapping, loader_fun: loader_fun - ] ++ Keyword.take(opts, [:backend, :log_params_diff]) + ] ++ Keyword.take(opts, [:backend, :log_params_diff, :preserve_source_types]) params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts) {:ok, params} diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex index c17fef85..858a5fc8 100644 --- a/lib/bumblebee/conversion/pytorch_params.ex +++ b/lib/bumblebee/conversion/pytorch_params.ex @@ -28,6 +28,11 @@ defmodule Bumblebee.Conversion.PyTorchParams do and loads the params file. Defaults to `Bumblebee.Conversion.PyTorchLoader.load!/1` + * `:preserve_source_types` - when `true`, preserves FP8 types from the + source file instead of converting them to the model's expected type. + This is useful for loading quantized models that use FP8 weights. + Defaults to `false` + """ @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{} def load_params!(model, input_template, path, opts \\ []) do @@ -36,6 +41,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do |> Keyword.validate!([ :log_params_diff, :backend, + :preserve_source_types, params_mapping: %{}, loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1 ]) @@ -58,7 +64,17 @@ defmodule Bumblebee.Conversion.PyTorchParams do model_state = Axon.trace_init(model, input_template) params_expr = model_state.data - {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping]) + preserve_source_types = opts[:preserve_source_types] || false + + {params, diff} = + init_params( + model, + params_expr, + pytorch_state, + opts[:params_mapping], + preserve_source_types + ) + model_state = %{model_state | data: params} params_complete? 
= diff.missing == [] and diff.mismatched == [] @@ -95,7 +111,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do Nx.Container.impl_for(value) != nil end - defp init_params(model, params_expr, pytorch_state, params_mapping) do + defp init_params(model, params_expr, pytorch_state, params_mapping, preserve_source_types) do layers = model |> Utils.Axon.nodes_with_names() @@ -103,7 +119,12 @@ prefixes = infer_prefixes(layers, pytorch_state, params_mapping) - diff = %{missing: [], mismatched: [], used_keys: []} + diff = %{ + missing: [], + mismatched: [], + used_keys: [], + preserve_source_types: preserve_source_types + } {params, diff} = layers @@ -155,7 +176,7 @@ case verify_param_shape(param_expr, value) do :ok -> - value = ensure_type(param_expr, value) + value = ensure_type(param_expr, value, diff.preserve_source_types) {value, diff} {:error, expected, actual} -> @@ -507,11 +528,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do Utils.Nx.map(expr, &Nx.shape/1) end - defp ensure_type(param_expr, value) do + defp ensure_type(param_expr, value, preserve_source_types) do Utils.Nx.zip_with(param_expr, value, fn expr, tensor -> - case {Nx.type(expr), Nx.type(tensor)} do - {type, type} -> tensor - {expected, _actual} -> Nx.as_type(tensor, expected) + case {Nx.type(expr), Nx.type(tensor), preserve_source_types} do + {type, type, _} -> tensor + # Preserve FP8 E4M3FN types when preserve_source_types is enabled + {_expected, {:f8_e4m3fn, 8}, true} -> tensor + {expected, _actual, _} -> Nx.as_type(tensor, expected) end end) end diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 647b711a..e3f7a789 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -438,6 +438,128 @@ defmodule Bumblebee.Layers do |> Nx.add(bias) end + @doc """ + Adds an FP8-aware dense layer to the network. + + This layer supports an optional `scale_inv` parameter for FP8 quantized weights. + When `scale_inv` is provided, it is used to dequantize the kernel block by block + before the matmul, accounting for FP8 quantization scaling. + + The kernel parameter uses the standard dense layout (transposed from PyTorch). + + ## Options + + * `:name` - layer name + + * `:kernel_initializer` - initializer for `kernel` weights. + Defaults to `:glorot_uniform` + + * `:use_bias` - whether the layer should add bias to the output. + Defaults to `false` + + * `:block_size` - the block size used for FP8 quantization.
+ Defaults to 128 + + """ + def fp8_aware_dense(%Axon{} = x, units, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :name, + kernel_initializer: :glorot_uniform, + use_bias: false, + block_size: 128 + ]) + + name = opts[:name] + block_size = opts[:block_size] + + kernel_shape = &Axon.Shape.dense_kernel(&1, units) + bias_shape = &Axon.Shape.dense_bias(&1, units) + + # Scale shape: [input_blocks, output_blocks] where block_size is typically 128 + # This matches the transposed layout from PyTorch (kernel is transposed, so is scale) + # For non-FP8 models, scale_inv will be initialized to 1.0 + scale_shape = fn input_shape -> + in_features = elem(input_shape, tuple_size(input_shape) - 1) + out_features = units + # Round up to handle cases where dimensions aren't exact multiples of block_size + out_blocks = div(out_features + block_size - 1, block_size) + in_blocks = div(in_features + block_size - 1, block_size) + # Note: [in_blocks, out_blocks] to match transposed scale_inv from PyTorch + {in_blocks, out_blocks} + end + + kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) + + # scale_inv is initialized to 1.0 (identity) for non-FP8 models + # For FP8 models, it will be loaded from the checkpoint + scale_inv = Axon.param("scale_inv", scale_shape, initializer: :ones) + + {inputs, op} = + if opts[:use_bias] do + bias = Axon.param("bias", bias_shape, initializer: :zeros) + {[x, kernel, scale_inv, bias], &fp8_aware_dense_impl(&1, &2, &3, &4, &5, block_size)} + else + {[x, kernel, scale_inv], &fp8_aware_dense_impl(&1, &2, &3, nil, &4, block_size)} + end + + Axon.layer(op, inputs, name: name, op_name: :fp8_aware_dense) + end + + deftransformp fp8_aware_dense_impl(x, kernel, scale_inv, bias, _opts, block_size) do + # Dequantize the kernel using scale_inv before matmul + # kernel: [in_features, out_features] + # scale_inv: [in_blocks, out_blocks] (transposed from PyTorch layout) + # Each 128x128 block of the kernel should be multiplied by its scale + kernel_dequant = dequantize_kernel(kernel, scale_inv, block_size) + + # Do the matmul with dequantized kernel + # x: [batch, seq_len, in_features] + # kernel_dequant: [in_features, out_features] + # result: [batch, seq_len, out_features] + result = Nx.dot(x, [-1], kernel_dequant, [0]) + + # Add bias if present + if bias do + Nx.add(result, bias) + else + result + end + end + + defp dequantize_kernel(kernel, scale_inv, block_size) do + # kernel: [in_features, out_features] + # scale_inv: [in_blocks, out_blocks] where in_blocks = ceil(in_features/128) + # + # To dequantize: for each element kernel[i,o], multiply by scale_inv[i/128, o/128] + # This is done by expanding scale_inv to match kernel shape + + {in_features, out_features} = Nx.shape(kernel) + {in_blocks, out_blocks} = Nx.shape(scale_inv) + + # Expand scale_inv to [in_features, out_features] + # Each scale value is replicated block_size times in both dimensions + scale_expanded = + scale_inv + # Replicate along input dimension: [in_blocks, out_blocks] -> [in_blocks * block_size, out_blocks] + |> Nx.reshape({in_blocks, 1, out_blocks}) + |> Nx.broadcast({in_blocks, block_size, out_blocks}) + |> Nx.reshape({in_blocks * block_size, out_blocks}) + # Replicate along output dimension: [..., out_blocks] -> [..., out_blocks * block_size] + |> Nx.reshape({in_blocks * block_size, out_blocks, 1}) + |> Nx.broadcast({in_blocks * block_size, out_blocks, block_size}) + |> Nx.reshape({in_blocks * block_size, out_blocks * block_size}) + + # Slice to exact kernel dimensions (in 
case they're not exact multiples of block_size) + scale_expanded = + scale_expanded + |> Nx.slice([0, 0], [in_features, out_features]) + + # Convert kernel to higher precision for dequantization, then multiply by scale + kernel_f32 = Nx.as_type(kernel, {:f, 32}) + Nx.multiply(kernel_f32, scale_expanded) + end + @doc """ Adds a 1-dimensional convolution layer to the network. diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 188b0ffe..72a4141e 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -63,7 +63,8 @@ defmodule Bumblebee.Layers.Transformer do :block_type, :attention_scale, :query_norm, - :key_norm + :key_norm, + :attention_dense ] opts = @@ -354,7 +355,8 @@ defmodule Bumblebee.Layers.Transformer do attention_scale: nil, rotary_embedding: nil, query_norm: nil, - key_norm: nil + key_norm: nil, + attention_dense: nil ]) name = opts[:name] @@ -386,6 +388,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding = opts[:rotary_embedding] query_norm = opts[:query_norm] key_norm = opts[:key_norm] + attention_dense = opts[:attention_dense] ffn_fun = case ffn do @@ -446,6 +449,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding: rotary_embedding, query_norm: query_norm, key_norm: key_norm, + attention_dense: attention_dense, name: join(name, "self_attention") ) @@ -491,6 +495,7 @@ defmodule Bumblebee.Layers.Transformer do attention_window_size: attention_window_size, attention_scale: attention_scale, rotary_embedding: rotary_embedding, + attention_dense: attention_dense, name: join(name, "cross_attention") ) @@ -772,7 +777,8 @@ defmodule Bumblebee.Layers.Transformer do output_use_bias: true, rotary_embedding: nil, query_norm: nil, - key_norm: nil + key_norm: nil, + attention_dense: nil ]) attention_mask = opts[:attention_mask] @@ -792,6 +798,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding = opts[:rotary_embedding] query_norm = opts[:query_norm] key_norm = opts[:key_norm] + attention_dense = opts[:attention_dense] query_use_bias = opts[:query_use_bias] key_use_bias = opts[:key_use_bias] @@ -804,9 +811,18 @@ defmodule Bumblebee.Layers.Transformer do inner_size = num_heads * attention_head_size inner_kv_size = num_key_value_heads * attention_head_size + # Helper to create dense layer, using custom attention_dense if provided + dense_fn = fn input, units, dense_opts -> + if attention_dense do + attention_dense.(input, units, dense_opts) + else + Axon.dense(input, units, dense_opts) + end + end + query = query - |> Axon.dense(inner_size, + |> dense_fn.(inner_size, kernel_initializer: kernel_initializer, name: join(name, "query"), use_bias: query_use_bias @@ -815,7 +831,7 @@ defmodule Bumblebee.Layers.Transformer do key = key - |> Axon.dense(inner_kv_size, + |> dense_fn.(inner_kv_size, kernel_initializer: kernel_initializer, name: join(name, "key"), use_bias: key_use_bias @@ -824,7 +840,7 @@ defmodule Bumblebee.Layers.Transformer do value = value - |> Axon.dense(inner_kv_size, + |> dense_fn.(inner_kv_size, kernel_initializer: kernel_initializer, name: join(name, "value"), use_bias: value_use_bias @@ -937,7 +953,7 @@ defmodule Bumblebee.Layers.Transformer do attention_output = attention_output |> Layers.flatten_trailing() - |> Axon.dense(hidden_size, + |> dense_fn.(hidden_size, kernel_initializer: kernel_initializer, name: join(name, "output"), use_bias: output_use_bias diff --git a/lib/bumblebee/text/qwen3.ex b/lib/bumblebee/text/qwen3.ex index 568dd4e6..eec6a004 100644 --- 
a/lib/bumblebee/text/qwen3.ex +++ b/lib/bumblebee/text/qwen3.ex @@ -343,6 +343,11 @@ defmodule Bumblebee.Text.Qwen3 do &Layers.rms_norm(&1, epsilon: spec.layer_norm_epsilon, channel_index: -1, name: &2) end + # Attention dense function using fp8_aware_dense for FP8 model support + attention_dense_fn = fn input, units, dense_opts -> + Layers.fp8_aware_dense(input, units, dense_opts) + end + Layers.Transformer.blocks(hidden_state, num_blocks: spec.num_blocks, num_attention_heads: spec.num_attention_heads, @@ -373,6 +378,7 @@ defmodule Bumblebee.Text.Qwen3 do ], query_norm: query_norm, key_norm: key_norm, + attention_dense: attention_dense_fn, name: join(name, "blocks") ) end @@ -381,17 +387,26 @@ defmodule Bumblebee.Text.Qwen3 do name = opts[:name] activation = opts[:activation] + # Use fp8_aware_dense for FP8 model support + # For non-FP8 models, scale_inv will be initialized to 1.0 (identity) intermediate = - Axon.dense(hidden_state, intermediate_size, + Layers.fp8_aware_dense(hidden_state, intermediate_size, name: join(name, "intermediate"), use_bias: false ) - gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + gate = + Layers.fp8_aware_dense(hidden_state, intermediate_size, + name: join(name, "gate"), + use_bias: false + ) hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation)) - Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + Layers.fp8_aware_dense(hidden_state, output_size, + name: join(name, "output"), + use_bias: false + ) end defp language_modeling_head(hidden_state, spec, opts) do @@ -454,16 +469,83 @@ defmodule Bumblebee.Text.Qwen3 do def params_mapping(spec) do %{ "embedder.token_embedding" => "model.embed_tokens", - "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", - "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", - "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", - "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + # Attention layers with FP8 scale_inv support + # Note: Both kernel and scale_inv need to be transposed to match Axon layout + "decoder.blocks.{n}.self_attention.query" => %{ + "kernel" => { + [{"model.layers.{n}.self_attn.q_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.self_attn.q_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, + "decoder.blocks.{n}.self_attention.key" => %{ + "kernel" => { + [{"model.layers.{n}.self_attn.k_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.self_attn.k_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, + "decoder.blocks.{n}.self_attention.value" => %{ + "kernel" => { + [{"model.layers.{n}.self_attn.v_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.self_attn.v_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, + "decoder.blocks.{n}.self_attention.output" => %{ + "kernel" => { + [{"model.layers.{n}.self_attn.o_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.self_attn.o_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, "decoder.blocks.{n}.self_attention.query_norm" => "model.layers.{n}.self_attn.q_norm", "decoder.blocks.{n}.self_attention.key_norm" => 
"model.layers.{n}.self_attn.k_norm", "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", - "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", - "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", - "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + # FFN layers with FP8 scale_inv support + # Note: Both kernel and scale_inv need to be transposed to match Axon layout + "decoder.blocks.{n}.ffn.gate" => %{ + "kernel" => { + [{"model.layers.{n}.mlp.gate_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.mlp.gate_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, + "decoder.blocks.{n}.ffn.intermediate" => %{ + "kernel" => { + [{"model.layers.{n}.mlp.up_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.mlp.up_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, + "decoder.blocks.{n}.ffn.output" => %{ + "kernel" => { + [{"model.layers.{n}.mlp.down_proj", "weight"}], + fn [kernel] -> Nx.transpose(kernel) end + }, + "scale_inv" => { + [{"model.layers.{n}.mlp.down_proj", "weight_scale_inv"}], + fn [scale] -> Nx.transpose(scale) end + } + }, "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", "output_norm" => "model.norm", "language_modeling_head.output" => diff --git a/mix.exs b/mix.exs index 160056d9..b77ff6be 100644 --- a/mix.exs +++ b/mix.exs @@ -34,15 +34,17 @@ defmodule Bumblebee.MixProject do {:axon, "~> 0.7.0"}, # {:axon, github: "elixir-nx/axon", override: true}, {:tokenizers, "~> 0.4"}, - {:nx, "~> 0.9.0 or ~> 0.10.0"}, - {:exla, ">= 0.0.0", only: [:dev, :test]}, - {:torchx, ">= 0.0.0", only: [:dev, :test]}, - # {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, - # {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]}, - # {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]}, + # TODO: Replace git deps with hex versions once nx >= X.X.X and safetensors >= 0.2.0 are released + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, + {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]}, + {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]}, + # {:nx, "~> 0.9.0 or ~> 0.10.0"}, + # {:exla, ">= 0.0.0", only: [:dev, :test]}, + # {:torchx, ">= 0.0.0", only: [:dev, :test]}, {:nx_image, "~> 0.1.0"}, {:unpickler, "~> 0.1.0"}, - {:safetensors, "~> 0.1.3"}, + # TODO: Replace git dep with hex version once safetensors >= 0.2.0 is released + {:safetensors, github: "elixir-nx/safetensors"}, {:jason, "~> 1.4.0"}, {:unzip, "~> 0.12.0 or ~> 0.13.0"}, {:progress_bar, "~> 3.0"}, diff --git a/mix.lock b/mix.lock index a8d7496b..04594dcd 100644 --- a/mix.lock +++ b/mix.lock @@ -11,7 +11,7 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"}, "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, "ex_doc": {:hex, :ex_doc, "0.39.1", "e19d356a1ba1e8f8cfc79ce1c3f83884b6abfcb79329d435d4bbb3e97ccc286e", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: 
false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "8abf0ed3e3ca87c0847dfc4168ceab5bedfe881692f1b7c45f4a11b232806865"}, - "exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"}, + "exla": {:git, "https://github.com/elixir-nx/nx.git", "29334ee430f0532387abd1b7ef9797c1ca76c12a", [sparse: "exla"]}, "fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"}, "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, @@ -20,7 +20,7 @@ "mime": {:hex, :mime, "2.0.7", "b8d739037be7cd402aee1ba0306edfdef982687ee7e9859bee6198c1e7e2f128", [:mix], [], "hexpm", "6171188e399ee16023ffc5b76ce445eb6d9672e2e241d2df6050f3c771e80ccd"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, - "nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "29334ee430f0532387abd1b7ef9797c1ca76c12a", [sparse: "nx"]}, "nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"}, "nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"}, "plug": {:hex, :plug, "1.18.1", 
"5067f26f7745b7e31bc3368bc1a2b818b9779faa959b49c934c17730efc911cf", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "57a57db70df2b422b564437d2d33cf8d33cd16339c1edb190cd11b1a3a546cc2"}, @@ -30,11 +30,11 @@ "progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"}, "ranch": {:hex, :ranch, "1.8.1", "208169e65292ac5d333d6cdbad49388c1ae198136e4697ae2f474697140f201c", [:make, :rebar3], [], "hexpm", "aed58910f4e21deea992a67bf51632b6d60114895eb03bb392bb733064594dd0"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"}, - "safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"}, + "safetensors": {:git, "https://github.com/elixir-nx/safetensors.git", "29df313b91aba3ddf9d85aa2f3445db5e8d2622b", []}, "stb_image": {:hex, :stb_image, "0.6.10", "76975279e2a130f53dc670bf6f6b1cdc4fbd7ab6293053e88e7fb6a7eae0e836", [:make, :mix], [{:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.8", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "26125372cfeda209084d3670417fab6819cfccd0e66c657678ecc48314369e8d"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, "tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"}, - "torchx": {:hex, :torchx, "0.10.2", "4b8529bfc4b0e641232497c99ef6d2508e652198840b212373333361352f0bae", [:mix], [{:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "cad541c64df8ddcbf50d9b0f212961632361a03050c8e01493f0fc8d4fed96d9"}, + "torchx": {:git, "https://github.com/elixir-nx/nx.git", "29334ee430f0532387abd1b7ef9797c1ca76c12a", [sparse: "torchx"]}, "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, "unzip": {:hex, :unzip, "0.13.0", 
"bf5ec6ac6063c69e6ec54c8b4a3b8dcd7a2719d28d10d7025776ab107957cde9", [:mix], [], "hexpm", "4bcb9892ecbf2042606b43ab685a1bffe03c14003e6246f5453db2c829237fd9"}, "xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"}, diff --git a/notebooks/qwen3.livemd b/notebooks/qwen3.livemd index c66fd3ce..4925fa69 100644 --- a/notebooks/qwen3.livemd +++ b/notebooks/qwen3.livemd @@ -16,6 +16,7 @@ Nx.global_default_backend(EXLA.Backend) In this notebook we explore the [Qwen3](https://qwenlm.github.io/blog/qwen3/) model family from Alibaba Cloud. Qwen3 is a series of large language models that includes: * **Text Generation** - Instruction-tuned models for conversational AI +* **FP8 Quantization** - Memory-efficient 8-bit floating point models * **Embeddings** - Dense vector representations for semantic search * **Rerankers** - Models to rerank search results for better relevance @@ -79,6 +80,67 @@ Nx.Serving.batched_run(Qwen3, prompt) |> Enum.each(&IO.write/1) +## Text Generation with FP8 Quantization + +Qwen3 models are also available in FP8 (8-bit floating point) quantized format, which significantly reduces memory usage while maintaining good quality. FP8 models use approximately half the memory of BF16 models. + +```elixir +repo = {:hf, "Qwen/Qwen3-4B-Instruct-2507-FP8"} + +{:ok, model_info} = Bumblebee.load_model(repo, + backend: EXLA.Backend, + preserve_source_types: true +) +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +:ok +``` + +The key option here is `preserve_source_types: true`, which keeps the FP8 weights in their native format instead of converting them to the model's default type. The model will automatically dequantize the weights during inference. + +Configure generation and create a serving: + +```elixir +generation_config = + Bumblebee.configure(generation_config, + max_new_tokens: 256, + temperature: 0.7, + strategy: %{type: :multinomial_sampling, top_p: 0.8, top_k: 20} + ) + +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 1024], + stream: true, + defn_options: [compiler: EXLA] + ) + +Kino.start_child({Nx.Serving, name: Qwen3FP8, serving: serving}) +``` + +Test the FP8 model with the same chat template: + +```elixir +user_input_fp8 = Kino.Input.textarea("User prompt (FP8)", default: "What are the benefits of quantized models?") +``` + +```elixir +user = Kino.Input.read(user_input_fp8) + +prompt = """ +<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +#{user}<|im_end|> +<|im_start|>assistant +""" + +Nx.Serving.batched_run(Qwen3FP8, prompt) |> Enum.each(&IO.write/1) +``` + + + ## Embeddings Qwen3 embedding models convert text into dense vector representations, useful for semantic search and similarity tasks. @@ -214,10 +276,11 @@ The reranker correctly identifies that the document directly answering "What is ## Summary -This notebook demonstrated three key capabilities of the Qwen3 model family: +This notebook demonstrated four key capabilities of the Qwen3 model family: 1. **Text Generation** - Conversational AI using instruction-tuned models -2. **Embeddings** - Creating semantic vector representations for similarity search -3. **Reranking** - Scoring and ranking documents by relevance to a query +2. 
**FP8 Quantization** - Memory-efficient inference using 8-bit floating point weights +3. **Embeddings** - Creating semantic vector representations for similarity search +4. **Reranking** - Scoring and ranking documents by relevance to a query -All three models work seamlessly with Bumblebee and can be used for various NLP applications. +All models work seamlessly with Bumblebee and can be used for various NLP applications. diff --git a/test/bumblebee/layers_test.exs b/test/bumblebee/layers_test.exs new file mode 100644 index 00000000..60c74947 --- /dev/null +++ b/test/bumblebee/layers_test.exs @@ -0,0 +1,158 @@ +defmodule Bumblebee.LayersTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + describe "fp8_aware_dense/3" do + test "dequantizes FP8 kernel with scale_inv" do + # Create a simple model with fp8_aware_dense + model = + Axon.input("input", shape: {nil, 4}) + |> Bumblebee.Layers.fp8_aware_dense(8, name: "dense", block_size: 2) + + # Create params with known values + # kernel: [4, 8] - input_features x output_features + # scale_inv: [2, 4] - ceil(4/2) x ceil(8/2) blocks + kernel = + Nx.tensor( + [ + [1, 2, 3, 4, 5, 6, 7, 8], + [1, 2, 3, 4, 5, 6, 7, 8], + [1, 2, 3, 4, 5, 6, 7, 8], + [1, 2, 3, 4, 5, 6, 7, 8] + ], + type: {:f, 32} + ) + + # Scale of 2.0 for all blocks means output should be 2x what it would be without scaling + scale_inv = + Nx.tensor( + [ + [2.0, 2.0, 2.0, 2.0], + [2.0, 2.0, 2.0, 2.0] + ], + type: {:f, 32} + ) + + params = %{ + "dense" => %{ + "kernel" => kernel, + "scale_inv" => scale_inv + } + } + + input = Nx.tensor([[1.0, 1.0, 1.0, 1.0]]) + + output = Axon.predict(model, params, %{"input" => input}) + + # Without scaling: input [1,1,1,1] dot kernel gives [4, 8, 12, 16, 20, 24, 28, 32] + # With scale_inv of 2.0: [8, 16, 24, 32, 40, 48, 56, 64] + expected = Nx.tensor([[8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 64.0]]) + + assert_all_close(output, expected) + end + + test "dequantizes with identity scale (1.0)" do + model = + Axon.input("input", shape: {nil, 4}) + |> Bumblebee.Layers.fp8_aware_dense(4, name: "dense", block_size: 2) + + kernel = + Nx.tensor( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ], + type: {:f, 32} + ) + + # Identity scale + scale_inv = + Nx.tensor( + [ + [1.0, 1.0], + [1.0, 1.0] + ], + type: {:f, 32} + ) + + params = %{ + "dense" => %{ + "kernel" => kernel, + "scale_inv" => scale_inv + } + } + + input = Nx.tensor([[2.0, 3.0, 4.0, 5.0]]) + output = Axon.predict(model, params, %{"input" => input}) + + # Identity matrix with scale 1.0 should return input unchanged + assert_all_close(output, input) + end + + test "handles non-block-aligned dimensions" do + # 3 input features, 5 output features with block_size 2 + # This tests the slicing logic for non-aligned dimensions + model = + Axon.input("input", shape: {nil, 3}) + |> Bumblebee.Layers.fp8_aware_dense(5, name: "dense", block_size: 2) + + # kernel: [3, 5] + kernel = Nx.broadcast(1.0, {3, 5}) + + # scale_inv: [ceil(3/2), ceil(5/2)] = [2, 3] + scale_inv = Nx.broadcast(1.0, {2, 3}) + + params = %{ + "dense" => %{ + "kernel" => kernel, + "scale_inv" => scale_inv + } + } + + input = Nx.tensor([[1.0, 1.0, 1.0]]) + output = Axon.predict(model, params, %{"input" => input}) + + # Sum of 3 ones = 3.0 for each output + expected = Nx.tensor([[3.0, 3.0, 3.0, 3.0, 3.0]]) + + assert_all_close(output, expected) + end + + test "includes bias when use_bias is true" do + model = + Axon.input("input", shape: {nil, 2}) + |> Bumblebee.Layers.fp8_aware_dense(2, name: 
"dense", block_size: 2, use_bias: true) + + kernel = + Nx.tensor( + [ + [1, 0], + [0, 1] + ], + type: {:f, 32} + ) + + scale_inv = Nx.tensor([[1.0]], type: {:f, 32}) + bias = Nx.tensor([10.0, 20.0], type: {:f, 32}) + + params = %{ + "dense" => %{ + "kernel" => kernel, + "scale_inv" => scale_inv, + "bias" => bias + } + } + + input = Nx.tensor([[1.0, 2.0]]) + output = Axon.predict(model, params, %{"input" => input}) + + # [1, 2] with identity kernel = [1, 2], plus bias [10, 20] = [11, 22] + expected = Nx.tensor([[11.0, 22.0]]) + + assert_all_close(output, expected) + end + end +end diff --git a/test/bumblebee/text/qwen3_test.exs b/test/bumblebee/text/qwen3_test.exs index 5134d6d5..677b4542 100644 --- a/test/bumblebee/text/qwen3_test.exs +++ b/test/bumblebee/text/qwen3_test.exs @@ -75,4 +75,33 @@ defmodule Bumblebee.Text.Qwen3Test do Nx.tensor([[-0.1487, -0.0071]]) ) end + + test ":for_causal_language_modeling with FP8 weights" do + assert {:ok, + %{model: model, params: %Axon.ModelState{data: params_data} = params, spec: spec}} = + Bumblebee.load_model( + {:hf, "roulis/tiny-fp8-qwen3"}, + preserve_source_types: true + ) + + assert %Bumblebee.Text.Qwen3{architecture: :for_causal_language_modeling} = spec + + # Verify FP8 weights are preserved + q_proj_kernel = params_data["decoder.blocks.0.self_attention.query"]["kernel"] + assert Nx.type(q_proj_kernel) == {:f8_e4m3fn, 8} + + # Verify scale_inv is loaded + q_proj_scale = params_data["decoder.blocks.0.self_attention.query"]["scale_inv"] + assert Nx.type(q_proj_scale) == {:f, 32} + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + # Model should run without error (dequantization happens internally) + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 1024} + end end