5 changes: 3 additions & 2 deletions lib/bumblebee.ex
@@ -607,7 +607,8 @@ defmodule Bumblebee do
:params_filename,
:log_params_diff,
:backend,
:type
:type,
:preserve_source_types
])

with {:ok, repo_files} <- get_repo_files(repository),
@@ -654,7 +655,7 @@ defmodule Bumblebee do
[
params_mapping: params_mapping,
loader_fun: loader_fun
] ++ Keyword.take(opts, [:backend, :log_params_diff])
] ++ Keyword.take(opts, [:backend, :log_params_diff, :preserve_source_types])

params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts)
{:ok, params}
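For reference, a minimal usage sketch of the new option at the `Bumblebee.load_model/2` level. The repository name is a placeholder, not part of this PR:

# Hypothetical FP8 checkpoint; "acme/llm-fp8" is a made-up repo name.
{:ok, %{model: model, params: params, spec: spec}} =
  Bumblebee.load_model({:hf, "acme/llm-fp8"}, preserve_source_types: true)

With the option set, FP8 tensors in the checkpoint are kept as-is rather than cast to the parameter types traced from the Axon model (see the `ensure_type/3` change below).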
39 changes: 31 additions & 8 deletions lib/bumblebee/conversion/pytorch_params.ex
@@ -28,6 +28,11 @@ defmodule Bumblebee.Conversion.PyTorchParams do
and loads the params file. Defaults to
`Bumblebee.Conversion.PyTorchLoader.load!/1`

* `:preserve_source_types` - when `true`, preserves FP8 types from the
source file instead of converting them to the model's expected type.
This is useful for loading quantized models that use FP8 weights.
Defaults to `false`

"""
@spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
def load_params!(model, input_template, path, opts \\ []) do
@@ -36,6 +41,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
|> Keyword.validate!([
:log_params_diff,
:backend,
:preserve_source_types,
params_mapping: %{},
loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1
])
@@ -58,7 +64,17 @@ defmodule Bumblebee.Conversion.PyTorchParams do
model_state = Axon.trace_init(model, input_template)

params_expr = model_state.data
{params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
preserve_source_types = opts[:preserve_source_types] || false

{params, diff} =
init_params(
model,
params_expr,
pytorch_state,
opts[:params_mapping],
preserve_source_types
)

model_state = %{model_state | data: params}

params_complete? = diff.missing == [] and diff.mismatched == []
@@ -95,15 +111,20 @@ defmodule Bumblebee.Conversion.PyTorchParams do
Nx.Container.impl_for(value) != nil
end

defp init_params(model, params_expr, pytorch_state, params_mapping) do
defp init_params(model, params_expr, pytorch_state, params_mapping, preserve_source_types) do
layers =
model
|> Utils.Axon.nodes_with_names()
|> Enum.filter(fn {layer, _name} -> layer.parameters != [] end)

prefixes = infer_prefixes(layers, pytorch_state, params_mapping)

diff = %{missing: [], mismatched: [], used_keys: []}
diff = %{
missing: [],
mismatched: [],
used_keys: [],
preserve_source_types: preserve_source_types
}

{params, diff} =
layers
@@ -155,7 +176,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do

case verify_param_shape(param_expr, value) do
:ok ->
value = ensure_type(param_expr, value)
value = ensure_type(param_expr, value, diff.preserve_source_types)
{value, diff}

{:error, expected, actual} ->
@@ -507,11 +528,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do
Utils.Nx.map(expr, &Nx.shape/1)
end

defp ensure_type(param_expr, value) do
defp ensure_type(param_expr, value, preserve_source_types) do
Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
case {Nx.type(expr), Nx.type(tensor)} do
{type, type} -> tensor
{expected, _actual} -> Nx.as_type(tensor, expected)
case {Nx.type(expr), Nx.type(tensor), preserve_source_types} do
{type, type, _} -> tensor
# Preserve FP8 E4M3FN types when preserve_source_types is enabled
{_expected, {:f8_e4m3fn, 8}, true} -> tensor
Comment on lines +535 to +536 (Member):
We likely don't want to do this here, because Axon may cast and it can lead to inconsistent behaviour (see #311). Ideally we want to apply an Axon.MixedPrecision policy, but we cannot determine it upfront. Also Axon policies apply per layer, but in this case we may have a layer where each param has different type. I need to think about the best way to address it and the loading API we should have.

{expected, _actual, _} -> Nx.as_type(tensor, expected)
end
end)
end
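The same flag at the conversion layer, as a hedged sketch; `model`, `input_template`, `params_mapping`, and the file names below are placeholders assumed to already exist:

params =
  Bumblebee.Conversion.PyTorchParams.load_params!(
    model,
    input_template,
    ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
    params_mapping: params_mapping,
    preserve_source_types: true
  )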
122 changes: 122 additions & 0 deletions lib/bumblebee/layers.ex
@@ -438,6 +438,128 @@ defmodule Bumblebee.Layers do
|> Nx.add(bias)
end

@doc """
Adds an FP8-aware dense layer to the network.

This layer carries a `scale_inv` parameter for block-wise FP8 quantized weights.
The kernel is dequantized before the matmul by multiplying each
`block_size` x `block_size` block with its corresponding `scale_inv` entry.
For non-FP8 models `scale_inv` is initialized to ones, so the layer behaves
like a regular dense layer.

The kernel parameter uses standard dense layout (transposed from PyTorch).

## Options

* `:name` - layer name

* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`

* `:use_bias` - whether the layer should add bias to the output.
Defaults to `false`

* `:block_size` - the block size used for FP8 quantization.
Defaults to 128

"""
def fp8_aware_dense(%Axon{} = x, units, opts \\ []) do
opts =
Keyword.validate!(opts, [
:name,
kernel_initializer: :glorot_uniform,
use_bias: false,
block_size: 128
])

name = opts[:name]
block_size = opts[:block_size]

kernel_shape = &Axon.Shape.dense_kernel(&1, units)
bias_shape = &Axon.Shape.dense_bias(&1, units)

# Scale shape: [input_blocks, output_blocks] where block_size is typically 128
# This matches the transposed layout from PyTorch (kernel is transposed, so is scale)
# For non-FP8 models, scale_inv will be initialized to 1.0
scale_shape = fn input_shape ->
in_features = elem(input_shape, tuple_size(input_shape) - 1)
out_features = units
# Round up to handle cases where dimensions aren't exact multiples of block_size
out_blocks = div(out_features + block_size - 1, block_size)
in_blocks = div(in_features + block_size - 1, block_size)
# Note: [in_blocks, out_blocks] to match transposed scale_inv from PyTorch
{in_blocks, out_blocks}
end

kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

# scale_inv is initialized to 1.0 (identity) for non-FP8 models
# For FP8 models, it will be loaded from the checkpoint
scale_inv = Axon.param("scale_inv", scale_shape, initializer: :ones)

{inputs, op} =
if opts[:use_bias] do
bias = Axon.param("bias", bias_shape, initializer: :zeros)
{[x, kernel, scale_inv, bias], &fp8_aware_dense_impl(&1, &2, &3, &4, &5, block_size)}
else
{[x, kernel, scale_inv], &fp8_aware_dense_impl(&1, &2, &3, nil, &4, block_size)}
end

Axon.layer(op, inputs, name: name, op_name: :fp8_aware_dense)
end

deftransformp fp8_aware_dense_impl(x, kernel, scale_inv, bias, _opts, block_size) do
# Dequantize the kernel using scale_inv before matmul
# kernel: [in_features, out_features]
# scale_inv: [in_blocks, out_blocks] (transposed from PyTorch layout)
# Each 128x128 block of the kernel should be multiplied by its scale
kernel_dequant = dequantize_kernel(kernel, scale_inv, block_size)

# Do the matmul with dequantized kernel
# x: [batch, seq_len, in_features]
# kernel_dequant: [in_features, out_features]
# result: [batch, seq_len, out_features]
result = Nx.dot(x, [-1], kernel_dequant, [0])

# Add bias if present
if bias do
Nx.add(result, bias)
else
result
end
end

defp dequantize_kernel(kernel, scale_inv, block_size) do
# kernel: [in_features, out_features]
# scale_inv: [in_blocks, out_blocks] where in_blocks = ceil(in_features/128)
#
# To dequantize: for each element kernel[i,o], multiply by scale_inv[i/128, o/128]
# This is done by expanding scale_inv to match kernel shape

{in_features, out_features} = Nx.shape(kernel)
{in_blocks, out_blocks} = Nx.shape(scale_inv)

# Expand scale_inv to [in_features, out_features]
# Each scale value is replicated block_size times in both dimensions
scale_expanded =
scale_inv
# Replicate along input dimension: [in_blocks, out_blocks] -> [in_blocks * block_size, out_blocks]
|> Nx.reshape({in_blocks, 1, out_blocks})
|> Nx.broadcast({in_blocks, block_size, out_blocks})
|> Nx.reshape({in_blocks * block_size, out_blocks})
# Replicate along output dimension: [..., out_blocks] -> [..., out_blocks * block_size]
|> Nx.reshape({in_blocks * block_size, out_blocks, 1})
|> Nx.broadcast({in_blocks * block_size, out_blocks, block_size})
|> Nx.reshape({in_blocks * block_size, out_blocks * block_size})

# Slice to exact kernel dimensions (in case they're not exact multiples of block_size)
scale_expanded =
scale_expanded
|> Nx.slice([0, 0], [in_features, out_features])

# Convert kernel to higher precision for dequantization, then multiply by scale
kernel_f32 = Nx.as_type(kernel, {:f, 32})
Nx.multiply(kernel_f32, scale_expanded)
end

@doc """
Adds a 1-dimensional convolution layer to the network.

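For intuition about `dequantize_kernel/3`, a small standalone sketch (not part of the diff) of the same reshape/broadcast expansion, using a toy `block_size` of 2 so the block structure is easy to see:

# Toy values; the 4x4 kernel and block_size of 2 are made up for illustration.
# In the real layer the kernel may be FP8 and is cast to f32 first.
block_size = 2
kernel = Nx.iota({4, 4}, type: :f32)
scale_inv = Nx.tensor([[1.0, 10.0], [100.0, 1000.0]])

{in_blocks, out_blocks} = Nx.shape(scale_inv)

scale_expanded =
  scale_inv
  # [2, 2] -> [4, 2]: repeat each row block_size times
  |> Nx.reshape({in_blocks, 1, out_blocks})
  |> Nx.broadcast({in_blocks, block_size, out_blocks})
  |> Nx.reshape({in_blocks * block_size, out_blocks})
  # [4, 2] -> [4, 4]: repeat each column block_size times
  |> Nx.reshape({in_blocks * block_size, out_blocks, 1})
  |> Nx.broadcast({in_blocks * block_size, out_blocks, block_size})
  |> Nx.reshape({in_blocks * block_size, out_blocks * block_size})

# Every element kernel[i][o] is multiplied by
# scale_inv[div(i, block_size)][div(o, block_size)]:
# rows 0-1 x cols 0-1 by 1.0, rows 0-1 x cols 2-3 by 10.0, and so on.
Nx.multiply(kernel, scale_expanded)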
30 changes: 23 additions & 7 deletions lib/bumblebee/layers/transformer.ex
@@ -63,7 +63,8 @@ defmodule Bumblebee.Layers.Transformer do
:block_type,
:attention_scale,
:query_norm,
:key_norm
:key_norm,
:attention_dense
]

opts =
@@ -354,7 +355,8 @@ defmodule Bumblebee.Layers.Transformer do
attention_scale: nil,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
attention_dense: nil
])

name = opts[:name]
@@ -386,6 +388,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]
attention_dense = opts[:attention_dense]

ffn_fun =
case ffn do
@@ -446,6 +449,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding: rotary_embedding,
query_norm: query_norm,
key_norm: key_norm,
attention_dense: attention_dense,
name: join(name, "self_attention")
)

@@ -491,6 +495,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_window_size: attention_window_size,
attention_scale: attention_scale,
rotary_embedding: rotary_embedding,
attention_dense: attention_dense,
name: join(name, "cross_attention")
)

@@ -772,7 +777,8 @@ defmodule Bumblebee.Layers.Transformer do
output_use_bias: true,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
attention_dense: nil
])

attention_mask = opts[:attention_mask]
@@ -792,6 +798,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]
attention_dense = opts[:attention_dense]

query_use_bias = opts[:query_use_bias]
key_use_bias = opts[:key_use_bias]
@@ -804,9 +811,18 @@ defmodule Bumblebee.Layers.Transformer do
inner_size = num_heads * attention_head_size
inner_kv_size = num_key_value_heads * attention_head_size

# Helper to create a dense layer, using the custom attention_dense if provided
dense_fn = fn input, units, dense_opts ->
if attention_dense do
attention_dense.(input, units, dense_opts)
else
Axon.dense(input, units, dense_opts)
end
end

query =
query
|> Axon.dense(inner_size,
|> dense_fn.(inner_size,
kernel_initializer: kernel_initializer,
name: join(name, "query"),
use_bias: query_use_bias
@@ -815,7 +831,7 @@ defmodule Bumblebee.Layers.Transformer do

key =
key
|> Axon.dense(inner_kv_size,
|> dense_fn.(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "key"),
use_bias: key_use_bias
@@ -824,7 +840,7 @@ defmodule Bumblebee.Layers.Transformer do

value =
value
|> Axon.dense(inner_kv_size,
|> dense_fn.(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "value"),
use_bias: value_use_bias
@@ -937,7 +953,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_output =
attention_output
|> Layers.flatten_trailing()
|> Axon.dense(hidden_size,
|> dense_fn.(hidden_size,
kernel_initializer: kernel_initializer,
name: join(name, "output"),
use_bias: output_use_bias
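Finally, a hedged sketch of how a model implementation could wire the two pieces together through the new `:attention_dense` option. The `blocks/2` call, the `spec` fields, and the option list below are illustrative, not taken from this diff:

# `fp8_aware_dense/3` accepts the same (input, units, opts) arguments that the
# internal dense_fn helper passes, so the capture can be handed over directly.
Bumblebee.Layers.Transformer.blocks(hidden_state,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  hidden_size: spec.hidden_size,
  attention_dense: &Bumblebee.Layers.fp8_aware_dense/3,
  # ...remaining block options (ffn, norms, etc.) elided
  name: "decoder.blocks"
)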