Commit e8e7f67

test: Add FP8 model tests and tiny model generator
- Add fp8_aware_dense layer unit tests
- Add FP8 Qwen3 model loading test using roulis/tiny-fp8-qwen3
- Include Python script to generate tiny FP8 test models
1 parent cb36413 commit e8e7f67

File tree

3 files changed: +326 -0 lines changed

test/bumblebee/layers_test.exs

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
defmodule Bumblebee.LayersTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  describe "fp8_aware_dense/3" do
    test "dequantizes FP8 kernel with scale_inv" do
      # Create a simple model with fp8_aware_dense
      model =
        Axon.input("input", shape: {nil, 4})
        |> Bumblebee.Layers.fp8_aware_dense(8, name: "dense", block_size: 2)

      # Create params with known values
      # kernel: [4, 8] - input_features x output_features
      # scale_inv: [2, 4] - ceil(4/2) x ceil(8/2) blocks
      kernel = Nx.tensor([
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8]
      ], type: {:f, 32})

      # Scale of 2.0 for all blocks means output should be 2x what it would be without scaling
      scale_inv = Nx.tensor([
        [2.0, 2.0, 2.0, 2.0],
        [2.0, 2.0, 2.0, 2.0]
      ], type: {:f, 32})

      params = %{
        "dense" => %{
          "kernel" => kernel,
          "scale_inv" => scale_inv
        }
      }

      input = Nx.tensor([[1.0, 1.0, 1.0, 1.0]])

      output = Axon.predict(model, params, %{"input" => input})

      # Without scaling: input [1,1,1,1] dot kernel gives [4, 8, 12, 16, 20, 24, 28, 32]
      # With scale_inv of 2.0: [8, 16, 24, 32, 40, 48, 56, 64]
      expected = Nx.tensor([[8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 64.0]])

      assert_all_close(output, expected)
    end

    test "dequantizes with identity scale (1.0)" do
      model =
        Axon.input("input", shape: {nil, 4})
        |> Bumblebee.Layers.fp8_aware_dense(4, name: "dense", block_size: 2)

      kernel = Nx.tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
      ], type: {:f, 32})

      # Identity scale
      scale_inv = Nx.tensor([
        [1.0, 1.0],
        [1.0, 1.0]
      ], type: {:f, 32})

      params = %{
        "dense" => %{
          "kernel" => kernel,
          "scale_inv" => scale_inv
        }
      }

      input = Nx.tensor([[2.0, 3.0, 4.0, 5.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # Identity matrix with scale 1.0 should return input unchanged
      assert_all_close(output, input)
    end

    test "handles non-block-aligned dimensions" do
      # 3 input features, 5 output features with block_size 2
      # This tests the slicing logic for non-aligned dimensions
      model =
        Axon.input("input", shape: {nil, 3})
        |> Bumblebee.Layers.fp8_aware_dense(5, name: "dense", block_size: 2)

      # kernel: [3, 5]
      kernel = Nx.broadcast(1.0, {3, 5})

      # scale_inv: [ceil(3/2), ceil(5/2)] = [2, 3]
      scale_inv = Nx.broadcast(1.0, {2, 3})

      params = %{
        "dense" => %{
          "kernel" => kernel,
          "scale_inv" => scale_inv
        }
      }

      input = Nx.tensor([[1.0, 1.0, 1.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # Sum of 3 ones = 3.0 for each output
      expected = Nx.tensor([[3.0, 3.0, 3.0, 3.0, 3.0]])

      assert_all_close(output, expected)
    end

    test "includes bias when use_bias is true" do
      model =
        Axon.input("input", shape: {nil, 2})
        |> Bumblebee.Layers.fp8_aware_dense(2, name: "dense", block_size: 2, use_bias: true)

      kernel = Nx.tensor([
        [1, 0],
        [0, 1]
      ], type: {:f, 32})

      scale_inv = Nx.tensor([[1.0]], type: {:f, 32})
      bias = Nx.tensor([10.0, 20.0], type: {:f, 32})

      params = %{
        "dense" => %{
          "kernel" => kernel,
          "scale_inv" => scale_inv,
          "bias" => bias
        }
      }

      input = Nx.tensor([[1.0, 2.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # [1, 2] with identity kernel = [1, 2], plus bias [10, 20] = [11, 22]
      expected = Nx.tensor([[11.0, 22.0]])

      assert_all_close(output, expected)
    end
  end
end
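
The block-wise dequantization these tests exercise can be sketched with plain Nx operations. This is a minimal sketch, not the fp8_aware_dense implementation: it assumes block-aligned dimensions and a kernel already cast to f32, and simply expands each scale_inv entry over its block_size x block_size tile before multiplying.

# kernel {4, 8} of ones, scale_inv {2, 4} of 2.0, block_size 2
kernel = Nx.broadcast(1.0, {4, 8})
scale_inv = Nx.broadcast(2.0, {2, 4})
block_size = 2

# Expand each scale over its block, then scale the kernel element-wise
scales =
  scale_inv
  |> Nx.new_axis(1)
  |> Nx.new_axis(3)
  |> Nx.tile([1, block_size, 1, block_size])
  |> Nx.reshape({4, 8})

dequantized = Nx.multiply(kernel, scales)

# With input [1.0, 1.0, 1.0, 1.0], every output column sums four ones scaled
# by 2.0, i.e. 8.0, matching the doubling the first test above asserts
Nx.dot(Nx.tensor([[1.0, 1.0, 1.0, 1.0]]), dequantized)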

test/bumblebee/text/qwen3_test.exs

Lines changed: 28 additions & 0 deletions
@@ -75,4 +75,32 @@ defmodule Bumblebee.Text.Qwen3Test do
      Nx.tensor([[-0.1487, -0.0071]])
    )
  end

  test ":for_causal_language_modeling with FP8 weights" do
    assert {:ok, %{model: model, params: %Axon.ModelState{data: params_data} = params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "roulis/tiny-fp8-qwen3"},
               preserve_source_types: true
             )

    assert %Bumblebee.Text.Qwen3{architecture: :for_causal_language_modeling} = spec

    # Verify FP8 weights are preserved
    q_proj_kernel = params_data["decoder.blocks.0.self_attention.query"]["kernel"]
    assert Nx.type(q_proj_kernel) == {:f8_e4m3fn, 8}

    # Verify scale_inv is loaded
    q_proj_scale = params_data["decoder.blocks.0.self_attention.query"]["scale_inv"]
    assert Nx.type(q_proj_scale) == {:f, 32}

    inputs = %{
      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    }

    # Model should run without error (dequantization happens internally)
    outputs = Axon.predict(model, params, inputs)

    assert Nx.shape(outputs.logits) == {1, 10, 1024}
  end
end
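
For reference, a quick shape check on the fixture this test loads (a sketch derived from the generator config further down, assuming the checkpoint's 128x128 block layout is kept as-is): with hidden_size 32 and q_size = 4 heads * head_dim 8 = 32, every scale_inv in the tiny model collapses to a single block.

block_size = 128
{out_features, in_features} = {32, 32}
{ceil(out_features / block_size), ceil(in_features / block_size)}
# => {1, 1} - each FP8 kernel in the tiny model carries a single scale value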

generate_fp8_qwen3.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
"""
Generate a tiny FP8 Qwen3 model for testing Bumblebee's FP8 support.

This creates a minimal model with:
- FP8 E4M3FN weights for linear layers
- Corresponding weight_scale_inv tensors (128x128 block scaling)
- Saved in safetensors format

Usage:
    python generate_fp8_qwen3.py
    # Then upload to HuggingFace: huggingface-cli upload bumblebee-testing/tiny-random-Qwen3ForCausalLM-FP8 ./tiny-fp8-qwen3
"""

import torch
import json
import os
from safetensors.torch import save_file

# Tiny model config matching existing tiny-random-Qwen3ForCausalLM
CONFIG = {
    "architectures": ["Qwen3ForCausalLM"],
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_attention_heads": 4,
    "num_hidden_layers": 2,
    "num_key_value_heads": 2,
    "vocab_size": 1024,
    "head_dim": 8,  # hidden_size / num_attention_heads
    "rms_norm_eps": 1e-6,
    "rope_theta": 1000000.0,
    "max_position_embeddings": 512,
    "torch_dtype": "float8_e4m3fn",
    "model_type": "qwen3",
    "use_qk_norm": True,
    "tie_word_embeddings": True,
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128]
    }
}

BLOCK_SIZE = 128


def create_fp8_weight(shape, seed=42):
    """Create a random FP8 E4M3FN weight tensor."""
    torch.manual_seed(seed)
    # Create random values in valid FP8 E4M3FN range (-448 to 448)
    weight_f32 = torch.randn(shape) * 0.1
    weight_fp8 = weight_f32.to(torch.float8_e4m3fn)
    return weight_fp8


def create_scale_inv(weight_shape):
    """Create scale_inv tensor for block-wise dequantization.

    Shape: [ceil(out_features/128), ceil(in_features/128)]
    For testing, use scale of 1.0 (identity) so dequantized = original.
    """
    out_features, in_features = weight_shape
    out_blocks = (out_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    in_blocks = (in_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Use 1.0 for identity scaling (easier to verify in tests)
    return torch.ones(out_blocks, in_blocks, dtype=torch.float32)


def generate_model():
    hidden_size = CONFIG["hidden_size"]
    intermediate_size = CONFIG["intermediate_size"]
    num_heads = CONFIG["num_attention_heads"]
    num_kv_heads = CONFIG["num_key_value_heads"]
    head_dim = CONFIG["head_dim"]
    vocab_size = CONFIG["vocab_size"]
    num_layers = CONFIG["num_hidden_layers"]

    tensors = {}
    seed = 0

    # Embedding (not quantized)
    tensors["model.embed_tokens.weight"] = torch.randn(vocab_size, hidden_size)

    for layer_idx in range(num_layers):
        prefix = f"model.layers.{layer_idx}"

        # Self-attention projections (FP8 quantized)
        q_size = num_heads * head_dim
        kv_size = num_kv_heads * head_dim

        # Q projection
        tensors[f"{prefix}.self_attn.q_proj.weight"] = create_fp8_weight((q_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.q_proj.weight_scale_inv"] = create_scale_inv((q_size, hidden_size))

        # K projection
        tensors[f"{prefix}.self_attn.k_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.k_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

        # V projection
        tensors[f"{prefix}.self_attn.v_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.v_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

        # O projection
        tensors[f"{prefix}.self_attn.o_proj.weight"] = create_fp8_weight((hidden_size, q_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.o_proj.weight_scale_inv"] = create_scale_inv((hidden_size, q_size))

        # QK norms (not quantized)
        tensors[f"{prefix}.self_attn.q_norm.weight"] = torch.ones(head_dim)
        tensors[f"{prefix}.self_attn.k_norm.weight"] = torch.ones(head_dim)

        # MLP (FP8 quantized)
        tensors[f"{prefix}.mlp.gate_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.gate_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

        tensors[f"{prefix}.mlp.up_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.up_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

        tensors[f"{prefix}.mlp.down_proj.weight"] = create_fp8_weight((hidden_size, intermediate_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.down_proj.weight_scale_inv"] = create_scale_inv((hidden_size, intermediate_size))

        # Layer norms (not quantized)
        tensors[f"{prefix}.input_layernorm.weight"] = torch.ones(hidden_size)
        tensors[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(hidden_size)

    # Final norm (not quantized)
    tensors["model.norm.weight"] = torch.ones(hidden_size)

    # LM head is tied to the embeddings (tie_word_embeddings: True), so no
    # separate lm_head tensor is saved and there is nothing to quantize here

    return tensors


def main():
    output_dir = "tiny-fp8-qwen3"
    os.makedirs(output_dir, exist_ok=True)

    # Generate model tensors
    tensors = generate_model()

    # Save as safetensors
    save_file(tensors, os.path.join(output_dir, "model.safetensors"))

    # Save config
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(CONFIG, f, indent=2)

    print(f"Model saved to {output_dir}/")
    print(f"Total tensors: {len(tensors)}")
    print("\nTo upload to HuggingFace:")
    print(f"  huggingface-cli upload bumblebee-testing/tiny-random-Qwen3ForCausalLM-FP8 {output_dir}")


if __name__ == "__main__":
    main()
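
Before uploading, the generated folder can be smoke-tested from an IEx session (a sketch: the local-directory repository form and the preserve_source_types option mirror the Qwen3 test above, and the path assumes the script's default output_dir):

{:ok, %{model: _model, params: _params, spec: spec}} =
  Bumblebee.load_model({:local, "tiny-fp8-qwen3"}, preserve_source_types: true)

spec.architecture
# => :for_causal_language_modeling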
