diff --git a/scripts/bf16_to_mxfp4.py b/scripts/bf16_to_mxfp4.py
new file mode 100644
index 0000000..531f26e
--- /dev/null
+++ b/scripts/bf16_to_mxfp4.py
@@ -0,0 +1,230 @@
+from argparse import ArgumentParser
+from typing import Optional, Tuple
+
+from glob import glob
+import json
+import os
+import re
+
+import torch
+from tqdm import tqdm
+import transformers
+from accelerate import init_empty_weights
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers.cache_utils import Cache
+
+# GPT-OSS
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
+
+
+from safetensors.torch import load_file, save_file
+
+
+# NOTE (yiakwy) : for quick verification purposes
+# from simple_py_mxfp4 import quantize_bf16_mxfp4
+
+from gpt_oss_triton_mxfp4 import quantize_bf16_mxfp4
+
+def has_tensor(weight_map, loaded_files, mxfp4_path, tensor_name):
+    """
+    Retrieves a tensor from the cached safetensor files, loading the file from disk if it is not cached yet.
+
+    Args:
+        weight_map (dict): Mapping from tensor name to the safetensor file that stores it.
+        loaded_files (dict): Cache of already loaded safetensor state dicts, keyed by file name.
+        mxfp4_path (str): Directory that contains the safetensor files.
+        tensor_name (str): The name of the tensor to retrieve.
+
+    Returns:
+        torch.Tensor: The retrieved tensor.
+
+    Raises:
+        KeyError: If the tensor does not exist in the safetensor file.
+    """
+    file_name = weight_map[tensor_name]
+    if file_name not in loaded_files:
+        file_path = os.path.join(mxfp4_path, file_name)
+        loaded_files[file_name] = load_file(file_path, device="cuda")
+    return loaded_files[file_name][tensor_name]
+
+
+def quantize(bf16_path, mxfp4_path, ref_weights_scale_inv_map_path=None):
+    ref_weights_scale_inv_map_f = os.path.join(
+        ref_weights_scale_inv_map_path, "weight_with_scale_inv_map.index.json"
+    )
+    with open(ref_weights_scale_inv_map_f, "r") as f:
+        s_model_index = json.load(f)
+        ref_weights_scale_inv_map = s_model_index["weight_with_scale_inv_map"]
+
+    os.makedirs(mxfp4_path, exist_ok=True)
+
+    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
+    with open(model_index_file, "r") as f:
+        model_index = json.load(f)
+    weight_map = model_index["weight_map"]
+
+    # Cache for loaded safetensor files
+    loaded_files = {}
+    bf16_weight_names = []
+    bf16_weight_scales = {}
+
+    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
+    safetensor_files.sort()
+
+    for safetensor_file in tqdm(safetensor_files):
+        file_name = os.path.basename(safetensor_file)
+        current_state_dict = load_file(safetensor_file, device="cuda")
+        loaded_files[file_name] = current_state_dict
+
+        new_state_dict = {}
+        for weight_name, weight in current_state_dict.items():
+            block_name = f"{weight_name}_blocks"
+            if (
+                ref_weights_scale_inv_map is not None
+                and ref_weights_scale_inv_map.get(block_name, None) is not None
+            ):
+                scale_name = f"{weight_name}_scales"
+
+                bf16_weight_names.append(weight_name)
+                bf16_weight_scales[scale_name] = file_name
+                # permute so quantization runs along the last axis of the transposed expert weight
+                weight_transpose = weight.permute(0, 2, 1).contiguous()
+                mxfp4_weight, mxfp4_scale = quantize_bf16_mxfp4(weight_transpose, 32)
+                # two FP4 codes per byte; every 16 bytes form one 32-value MXFP4 block
+                new_state_dict[block_name] = mxfp4_weight.view(*mxfp4_weight.shape[:-1], -1, 16).contiguous()
+                new_state_dict[scale_name] = mxfp4_scale.contiguous()
+            else:
+                print(f"skipping {weight_name} dtype={weight.dtype}...")
+                new_state_dict[weight_name] = weight
+
+        new_safetensor_file = os.path.join(mxfp4_path, file_name)
+        save_file(new_state_dict, new_safetensor_file)
+
+        del new_state_dict
+
+        if len(loaded_files) > 1:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            torch.cuda.empty_cache()
+
+    # Update the model index: each quantized weight is replaced by a
+    # "<name>_blocks" / "<name>_scales" pair and its bf16 entry is dropped.
+    new_model_index_file = os.path.join(mxfp4_path, "model.safetensors.index.json")
+
+    for weight_name in bf16_weight_names:
+        scale_name = f"{weight_name}_scales"
+        block_name = f"{weight_name}_blocks"
+
+        weight_map[scale_name] = bf16_weight_scales[scale_name]
+        weight_map[block_name] = weight_map[weight_name]
+
+        weight_map.pop(weight_name)
+
+    with open(new_model_index_file, "w") as f:
+        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+
+
+def read_mxfp4_list(bf16_path):
+    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
+    with open(model_index_file, "r") as f:
+        model_index = json.load(f)
+    weight_map = model_index["weight_map"]
+    mxfp4_weights_inv_map = {}
+
+    # Cache for loaded safetensor files
+    loaded_files = {}
+    mxfp4_weights_name = []
+
+    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
+    safetensor_files.sort()
+
+    for safetensor_file in tqdm(safetensor_files):
+        file_name = os.path.basename(safetensor_file)
+        current_state_dict = load_file(safetensor_file, device="cuda")
+        loaded_files[file_name] = current_state_dict
+
+        for weight_name, weight in current_state_dict.items():
+            if weight_name.endswith("scales"):
+                print(f"skipping {weight_name} dtype={weight.dtype}...")
+                continue
+            elif weight.element_size() == 1:  # MXFP4
+                scale_name = weight_name.replace("blocks", "scales")
+                try:
+                    weight_scales = has_tensor(
+                        weight_map, loaded_files, bf16_path, scale_name
+                    )
+                    mxfp4_weights_name.append(weight_name)
+                    mxfp4_weights_inv_map[weight_name] = weight_map[scale_name]
+                except KeyError:
+                    print(
+                        f"Warning: Missing scales tensor for {weight_name}, skipping conversion ..."
+                    )
+            else:
+                print(f"skipping {weight_name} dtype={weight.dtype}...")
+
+        if len(loaded_files) > 2:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            torch.cuda.empty_cache()
+
+    weights_with_scale_inv = os.path.join(
+        bf16_path, "weight_with_scale_inv_map.index.json"
+    )
+    with open(weights_with_scale_inv, "w") as f:
+        json.dump(
+            {"metadata": {}, "weight_with_scale_inv_map": mxfp4_weights_inv_map},
+            f,
+            indent=2,
+        )
+
+
+def _verify_tokenizer_and_model(hf_tokenizer, model):
+    texts = ["你是谁?"]  # ["Give me a short introduction to large language model.", ]
+    messages = [
+        {"role": "user", "content": text} for text in texts
+    ]
+
+    prompts = hf_tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True)
+
+    model_inputs = hf_tokenizer([prompts], return_tensors="pt").to(model.device)
+    outputs_ids = model.generate(**model_inputs, max_new_tokens=256)
+
+    outputs_ids = [
+        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, outputs_ids)
+    ]
+
+    response = hf_tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)[0]
+    print(f"response : {response}")
+
+
+def verify_tokenizer_and_model(hf_tokenizer_path, model):
+    hf_tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)
+
+    _verify_tokenizer_and_model(hf_tokenizer, model)
+
+
+def load_and_verify_hf_model(source_model):
+    model = AutoModelForCausalLM.from_pretrained(
+        source_model, torch_dtype="auto", device_map="auto"
+    )
+
+    verify_tokenizer_and_model(source_model, model)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--source_model", default=None, type=str, required=False, help="Path to the source checkpoint (bf16 for conversion, MXFP4 for --get_scaled_weights)."
+    )
+    parser.add_argument(
+        "--output_dir", default=None, type=str, required=False, help="Where to save the converted MXFP4 model."
+    )
+    parser.add_argument(
+        "--get_scaled_weights", action="store_true", required=False, help="Scan an MXFP4 checkpoint and dump weight_with_scale_inv_map.index.json instead of converting."
+    )
+    args = parser.parse_args()
+
+    if not args.output_dir:
+        if args.get_scaled_weights:
+            read_mxfp4_list(args.source_model)
+        else:
+            load_and_verify_hf_model(args.source_model)
+    else:
+        quantize(args.source_model, args.output_dir, ref_weights_scale_inv_map_path="/root/models/gpt-oss-120b")
diff --git a/scripts/gpt_oss_triton_mxfp4.py b/scripts/gpt_oss_triton_mxfp4.py
new file mode 100644
index 0000000..71b8251
--- /dev/null
+++ b/scripts/gpt_oss_triton_mxfp4.py
@@ -0,0 +1,9 @@
+import torch
+import triton
+
+import triton_kernels
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+
+def quantize_bf16_mxfp4(w, block_size=None):
+    # block_size is kept for interface parity with simple_py_mxfp4.quantize_bf16_mxfp4;
+    # downcast_to_mxfp applies the standard 32-element MXFP4 blocking along `axis`.
+    w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=-1)
+    return w, w_scale
diff --git a/scripts/simple_py_mxfp4.py b/scripts/simple_py_mxfp4.py
new file mode 100644
index 0000000..d1782bd
--- /dev/null
+++ b/scripts/simple_py_mxfp4.py
@@ -0,0 +1,124 @@
+import torch
+
+
+# The functions below are adapted from https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/0bea1c31d75761002aad4290e572cf7c512d8b3a/modelopt/torch/quantization/qtensor/mxfp4_tensor.py#L25
+
+E2M1_max = 6.0
+
+E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
+
+# TODO (yiakwy) : create from E2M1_values
+FP4_VALUES = [
+    +0.0,
+    +0.5,
+    +1.0,
+    +1.5,
+    +2.0,
+    +3.0,
+    +4.0,
+    +6.0,
+    -0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+]
+
+E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
+
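+# Worked example of the block math implemented below (numbers chosen for illustration):
+# for a 32-value block whose max |x| is 3.0, descale = 3.0 / 6.0 = 0.5, so
+# e8m0_scale = ceil(log2(0.5)) = -1 and the stored scale byte is -1 + 127 = 126.
+# Each value is divided by 2**-1 (i.e. doubled) before cast_fp4(): 1.4 -> 2.8, which
+# exceeds the bound 2.5 but not 3.5, so it maps to the E2M1 code for 3.0 and
+# dequantizes to 3.0 * 2**-1 = 1.5.
+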
+def pack_uint4x2_to_uint8(x):
+    # Pack pairs of 4-bit codes into one uint8 each. This assumes the last
+    # dimension is even, which always holds for 32-element MXFP4 blocks.
+    left_side = x[..., 0::2]   # even indices (0, 2, 4, ...)
+    right_side = x[..., 1::2]  # odd indices (1, 3, 5, ...)
+    new_data = right_side.clone() << 4  # odd indices (higher addresses) go to the high nibble
+    new_data[..., : left_side.shape[-1]] += left_side  # even indices go to the low nibble
+    return new_data
+
+def cast_fp4(x):
+    sign = torch.sign(x)
+    sign_bit = (2 - sign) // 2
+    ord_ = torch.sum(
+        (x.abs().unsqueeze(-1) - E2M1_bounds.to(x.device)) > 0, dim=-1
+    )
+    fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
+    return fp4_val
+
+# Quantize a bf16 tensor to packed MXFP4: uint8 codes (two values per byte) plus per-block E8M0 scales.
+def quantize_bf16_mxfp4(input: torch.Tensor, block_size: int | None):
+    block_size = block_size or 32
+
+    input = input.view(-1, block_size)
+
+    input_amax = input.abs().max(dim=-1, keepdim=True).values
+    descale = input_amax / E2M1_max
+
+    min_value = torch.tensor(-127.0, device=descale.device)
+    e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
+
+    original_shape = input.shape
+    input = (input / torch.exp2(e8m0_scale)).view(original_shape)
+    input_q = cast_fp4(input)
+    input_q = pack_uint4x2_to_uint8(input_q)
+
+    e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
+    return input_q, e8m0_scale
+
+
+# The function below is adapted from the gpt-oss repo (MXFP4 -> bf16 dequantization).
+def convert_fp4_bf16(
+    blocks,
+    scales,
+    *,
+    dtype: torch.dtype = torch.bfloat16,
+    rows_per_chunk: int = 32768 * 1024,
+) -> torch.Tensor:
+    import math
+
+    # Check if blocks and scales are on CPU, and move to GPU if so
+    if not blocks.is_cuda and torch.cuda.is_available():
+        blocks = blocks.cuda()
+        scales = scales.cuda()
+
+    scales = scales.to(torch.int32) - 127
+
+    assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"
+
+    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
+
+    *prefix_shape, G, B = blocks.shape
+    rows_total = math.prod(prefix_shape) * G
+
+    blocks = blocks.reshape(rows_total, B)
+    scales = scales.reshape(rows_total, 1)
+
+    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
+
+    for r0 in range(0, rows_total, rows_per_chunk):
+        r1 = min(r0 + rows_per_chunk, rows_total)
+
+        blk = blocks[r0:r1]
+        exp = scales[r0:r1]
+
+        # nibble indices -> int64
+        idx_lo = (blk & 0x0F).to(torch.long)
+        idx_hi = (blk >> 4).to(torch.long)
+
+        sub = out[r0:r1]
+        sub[:, 0::2] = lut[idx_lo]
+        sub[:, 1::2] = lut[idx_hi]
+
+        torch.ldexp(sub, exp, out=sub)
+        del idx_lo, idx_hi, blk, exp, sub
+
+    out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
+
+    # TODO: Delete after making sure this is not necessary! since we go back to cpu in the end in create_quantized_param using .to(target_device)
+    # Move back to CPU if needed
+    # if need_to_move_back:
+    #     out = out.cpu()
+    del blocks, scales, lut
+    return out
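+
+
+# Minimal round-trip sanity check (an illustrative sketch, not used by the conversion script):
+#   w = torch.randn(16, 32, dtype=torch.bfloat16)
+#   q, s = quantize_bf16_mxfp4(w, 32)        # q: uint8 (16, 16) packed codes, s: uint8 (16, 1) biased exponents
+#   w_hat = convert_fp4_bf16(q, s.view(-1))  # bf16, flattened to (512,) values
+#   print((w_hat.view(16, 32).cpu() - w).abs().max())  # should be small relative to w.abs().max()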