
[Feat] add bf16 sft to mxfp4 conversion #108

Closed

230 changes: 230 additions & 0 deletions scripts/bf16_to_mxfp4.py
@@ -0,0 +1,230 @@
from argparse import ArgumentParser
from typing import Optional, Tuple

from glob import glob
import json
import os
import re

import torch
from tqdm import tqdm
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import Cache

# GPT-OSS
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM


from safetensors.torch import load_file, save_file


# NOTE (yiakwy) : for quick verification purpose
# from simple_py_mxfp4 import quantize_bf16_mxfp4

from gpt_oss_triton_mxfp4 import quantize_bf16_mxfp4

def has_tensor(weight_map, loaded_files, mxfp4_path, tensor_name):
    """
    Retrieves a tensor from the cached safetensor shards, loading the owning shard
    from disk if it is not cached yet.

    Args:
        weight_map (dict): Mapping from tensor name to the safetensor shard that stores it.
        loaded_files (dict): Cache of already loaded shards, keyed by file name.
        mxfp4_path (str): Directory containing the safetensor shards.
        tensor_name (str): The name of the tensor to retrieve.

    Returns:
        torch.Tensor: The retrieved tensor.

    Raises:
        KeyError: If the tensor does not exist in the shard.
    """
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(mxfp4_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
return loaded_files[file_name][tensor_name]


def quantize(bf16_path, mxfp4_path, ref_weights_scale_inv_map_path=None):
    # The reference map (written by read_mxfp4_list when scanning the original
    # MXFP4 checkpoint) lists which weights carry *_blocks / *_scales and must
    # therefore be re-quantized.
    assert ref_weights_scale_inv_map_path is not None, "path to a reference MXFP4 checkpoint is required"
    ref_weights_scale_inv_map_f = os.path.join(
        ref_weights_scale_inv_map_path, "weight_with_scale_inv_map.index.json"
    )
    with open(ref_weights_scale_inv_map_f, "r") as f:
        s_model_index = json.load(f)
    ref_weights_scale_inv_map = s_model_index["weight_with_scale_inv_map"]

os.makedirs(mxfp4_path, exist_ok=True)

model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]

# Cache for loaded safetensor files
loaded_files = {}
bf16_weight_names = []
bf16_weight_scales = {}

safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
safetensor_files.sort()

for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            block_name = f"{weight_name}_blocks"
            if (
                ref_weights_scale_inv_map is not None
                and ref_weights_scale_inv_map.get(block_name, None) is not None
            ):
                scale_name = f"{weight_name}_scales"

                bf16_weight_names.append(weight_name)
                bf16_weight_scales[scale_name] = file_name
                # The BF16 parameters are stored transposed relative to the
                # *_blocks layout, so swap the last two dims and quantize along
                # the (new) last axis; the packed uint8 output is then regrouped
                # into 16-byte blocks (32 FP4 values per block).
                weight_transpose = weight.permute(0, 2, 1).contiguous()
                mxfp4_weight, mxfp4_scale = quantize_bf16_mxfp4(weight_transpose, 32)
                new_state_dict[block_name] = mxfp4_weight.view(*mxfp4_weight.shape[:-1], -1, 16).contiguous()
                new_state_dict[scale_name] = mxfp4_scale.contiguous()
            else:
                print(f"skipping {weight_name} dtype={weight.dtype}...")
                new_state_dict[weight_name] = weight

new_safetensor_file = os.path.join(mxfp4_path, file_name)
save_file(new_state_dict, new_safetensor_file)

del new_state_dict

        # Keep at most one shard resident to bound GPU memory usage
        if len(loaded_files) > 1:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

# Update model index
new_model_index_file = os.path.join(mxfp4_path, "model.safetensors.index.json")

for weight_name in bf16_weight_names:
scale_name = f"{weight_name}_scales"
block_name = f"{weight_name}_blocks"

weight_map[scale_name] = bf16_weight_scales[scale_name]
weight_map[block_name] = weight_map[weight_name]

weight_map.pop(weight_name)

with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


def read_mxfp4_list(bf16_path):
model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
mxfp4_weights_inv_map = {}

# Cache for loaded safetensor files
loaded_files = {}
mxfp4_weights_name = []

safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
safetensor_files.sort()

for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict

for weight_name, weight in current_state_dict.items():
if weight_name.endswith("scales"):
print(f"skipping {weight_name} dtype={weight.dtype}...")
continue
elif weight.element_size() == 1: # MXFP4
scale_name = weight_name.replace("blocks", "scales")
try:
weight_scales = has_tensor(
weight_map, loaded_files, bf16_path, scale_name
)
mxfp4_weights_name.append(weight_name)
mxfp4_weights_inv_map[weight_name] = weight_map[scale_name]
except KeyError:
print(
f"Warning: Missing scales tensor for {weight_name}, skipping conversion ..."
)
else:
print(f"skipping {weight_name} dtype={weight.dtype}...")

if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()

weights_with_scale_inv = os.path.join(
bf16_path, "weight_with_scale_inv_map.index.json"
)
with open(weights_with_scale_inv, "w") as f:
json.dump(
{"metadata": {}, "weight_with_scale_inv_map": mxfp4_weights_inv_map},
f,
indent=2,
)


def _verify_tokenizer_and_model(hf_tokenizer, model):
texts = ["你是谁?"] # ["Give me a short introduction to large language model.", ]
messages = [
{"role": "user", "content": text} for text in texts
]

prompts = hf_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True)

model_inputs = hf_tokenizer([prompts], return_tensors="pt").to(model.device)
outputs_ids = model.generate(**model_inputs, max_new_tokens=256)

outputs_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, outputs_ids)
]

response = hf_tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)[0]
print(f"response : {response}")


def verify_tokenizer_and_model(hf_tokenizer_path, model):
hf_tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)

_verify_tokenizer_and_model(hf_tokenizer, model)


def load_and_verify_hf_model(source_model):
model = AutoModelForCausalLM.from_pretrained(
source_model, torch_dtype="auto", device_map="auto"
)

verify_tokenizer_and_model(source_model, model)


if __name__ == "__main__":
parser = ArgumentParser()
    parser.add_argument(
        "--source_model", default=None, type=str, required=False,
        help="Path to the source model: a BF16 SFT checkpoint to convert, or an existing MXFP4 checkpoint when --get_scaled_weights is set.",
    )
    parser.add_argument(
        "--output_dir", default=None, type=str, required=False,
        help="Where to save the converted MXFP4 model.",
    )
    parser.add_argument(
        "--get_scaled_weights", action="store_true", required=False,
        help="Scan an MXFP4 checkpoint and dump its weight-to-scales map (weight_with_scale_inv_map.index.json).",
    )
args = parser.parse_args()

if not args.output_dir:
if args.get_scaled_weights:
read_mxfp4_list(args.source_model)
else:
load_and_verify_hf_model(args.source_model)
else:
quantize(args.source_model, args.output_dir, ref_weights_scale_inv_map_path="/root/models/gpt-oss-120b")
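
For context, a minimal sketch of the intended two-step workflow, calling the module's functions directly (the checkpoint paths are placeholders; the CLI path in __main__ hard-codes /root/models/gpt-oss-120b as the reference checkpoint):

from bf16_to_mxfp4 import read_mxfp4_list, quantize

ref_mxfp4 = "/path/to/gpt-oss-120b"            # original MXFP4 release (placeholder)
bf16_sft = "/path/to/gpt-oss-120b-bf16-sft"    # fine-tuned BF16 checkpoint (placeholder)
out_dir = "/path/to/gpt-oss-120b-mxfp4-sft"    # converted output (placeholder)

# Step 1: record which weights carry *_blocks / *_scales in the original checkpoint;
# this writes weight_with_scale_inv_map.index.json inside ref_mxfp4.
read_mxfp4_list(ref_mxfp4)

# Step 2: quantize the BF16 SFT weights back into that MXFP4 layout.
quantize(bf16_sft, out_dir, ref_weights_scale_inv_map_path=ref_mxfp4)

The CLI equivalent is one run with --source_model <mxfp4_ckpt> --get_scaled_weights followed by one with --source_model <bf16_sft_ckpt> --output_dir <out_dir>.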
9 changes: 9 additions & 0 deletions scripts/gpt_oss_triton_mxfp4.py
@@ -0,0 +1,9 @@
import torch
import triton

import triton_kernels
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp

def quantize_bf16_mxfp4(w, block_size=None):
    # block_size is kept for API compatibility with simple_py_mxfp4; downcast_to_mxfp
    # always uses the MXFP4 block size of 32 along the chosen axis and returns
    # packed uint8 blocks plus uint8 E8M0 scales.
    w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=-1)
    return w, w_scale
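
A quick shape check for this Triton-backed path, assuming triton_kernels is installed and a CUDA device is available (the tensor sizes are illustrative only):

import torch
from gpt_oss_triton_mxfp4 import quantize_bf16_mxfp4

w = torch.randn(4, 256, 256, dtype=torch.bfloat16, device="cuda")
w_q, w_scale = quantize_bf16_mxfp4(w)

# Two FP4 values are packed per uint8 along the quantized (last) axis, and one
# E8M0 scale is stored per 32-element block, so the expected shapes are:
print(w_q.shape, w_q.dtype)          # (4, 256, 128), torch.uint8
print(w_scale.shape, w_scale.dtype)  # (4, 256, 8), torch.uint8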
124 changes: 124 additions & 0 deletions scripts/simple_py_mxfp4.py
@@ -0,0 +1,124 @@
import torch


# the functions are adapted from https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/0bea1c31d75761002aad4290e572cf7c512d8b3a/modelopt/torch/quantization/qtensor/mxfp4_tensor.py#L25

E2M1_max = 6.0

E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]

# TODO (yiakwy) : create from E2M1_values
FP4_VALUES = [
+0.0,
+0.5,
+1.0,
+1.5,
+2.0,
+3.0,
+4.0,
+6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]

E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

def pack_uint4x2_to_uint8(x):
    # Pack pairs of 4-bit codes into single bytes. The last dimension is assumed
    # to be even (MXFP4 blocks hold 32 values), so no padding is applied.
    left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
    right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
    new_data = right_side.clone() << 4  # Odd indices (higher addresses) go in the high nibble
    new_data[..., : left_side.shape[-1]] += left_side  # Even indices go in the low nibble
    return new_data

def cast_fp4(x):
sign = torch.sign(x)
sign_bit = (2 - sign) // 2
ord_ = torch.sum(
(x.abs().unsqueeze(-1) - E2M1_bounds.to(x.device)) > 0, dim=-1
)
fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
return fp4_val

# convert a bf16 tensor to packed MXFP4: uint8 blocks (2 values per byte) plus uint8 E8M0 scales
def quantize_bf16_mxfp4(input: torch.Tensor, block_size: int | None):
    block_size = block_size or 32

    # Flatten into (num_blocks, block_size) rows; one shared scale per row
    input = input.view(-1, block_size)

    input_amax = input.abs().max(dim=-1, keepdim=True).values
    descale = input_amax / E2M1_max

    # E8M0 exponent: ceil(log2) of the per-block descale, clamped at -127
    min_value = torch.tensor(-127.0, device=descale.device)
    e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))

    original_shape = input.shape
    input = (input / torch.exp2(e8m0_scale)).view(original_shape)
    input_q = cast_fp4(input)
    input_q = pack_uint4x2_to_uint8(input_q)  # (num_blocks, block_size // 2)

    # Store the exponent with the standard +127 E8M0 bias
    e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
    return input_q, e8m0_scale


# the function is adapted from GPT_OSS repo
def convert_fp4_bf16(
blocks,
scales,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 32768 * 1024,
) -> torch.Tensor:
import math

# Check if blocks and scales are on CPU, and move to GPU if so
if not blocks.is_cuda and torch.cuda.is_available():
blocks = blocks.cuda()
scales = scales.cuda()

scales = scales.to(torch.int32) - 127

assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"

lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G

blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)

out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)

blk = blocks[r0:r1]
exp = scales[r0:r1]

# nibble indices -> int64
idx_lo = (blk & 0x0F).to(torch.long)
idx_hi = (blk >> 4).to(torch.long)

sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
sub[:, 1::2] = lut[idx_hi]

torch.ldexp(sub, exp, out=sub)
del idx_lo, idx_hi, blk, exp, sub

out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)

    # TODO: delete once confirmed unnecessary; create_quantized_param moves the
    # result back to the target device with .to(target_device) at the end.
    # Move back to CPU if needed
    # if need_to_move_back:
    #     out = out.cpu()
del blocks, scales, lut
return out
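
A round-trip sanity check for this pure-PyTorch reference path, quantizing with quantize_bf16_mxfp4 and dequantizing with convert_fp4_bf16 (a sketch; the tensor size is illustrative, and the scales are squeezed because convert_fp4_bf16 expects them without the trailing singleton dimension):

import torch
from simple_py_mxfp4 import quantize_bf16_mxfp4, convert_fp4_bf16

x = torch.randn(64, 128, dtype=torch.bfloat16)
blocks, scales = quantize_bf16_mxfp4(x, 32)        # (256, 16) uint8, (256, 1) uint8
y = convert_fp4_bf16(blocks, scales.squeeze(-1))   # flat tensor of 256 * 32 bf16 values

# Compare against the input; MXFP4 is coarse, so expect a visible but bounded error.
err = (y.reshape(x.shape).to(x.device, torch.float32) - x.float()).abs().max()
print(f"max abs round-trip error: {err.item():.3f}")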