From 413bfab05287e541f86ef5542b1a020bcd513f72 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Wed, 10 Dec 2025 07:17:07 +0000 Subject: [PATCH] Add GPT OSS MaxText to vLLM mappings and helper functions. --- src/MaxText/integration/tunix/utils.py | 14 +- .../tunix/weight_mapping/__init__.py | 3 + .../tunix/weight_mapping/gpt_oss.py | 215 ++++++++++++++++++ 3 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/MaxText/integration/tunix/weight_mapping/gpt_oss.py diff --git a/src/MaxText/integration/tunix/utils.py b/src/MaxText/integration/tunix/utils.py index 2cf12c048..161202c52 100644 --- a/src/MaxText/integration/tunix/utils.py +++ b/src/MaxText/integration/tunix/utils.py @@ -14,8 +14,10 @@ """Utils for Tunix integration.""" +import inspect import re + import MaxText.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS @@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False): def to_hf_mapping(self): """Returns a mapping from MaxText parameter names to HuggingFace parameter names.""" if self.use_standalone_mappings: - return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping() + mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping + total_num_layers = self.config["num_hidden_layers"] + print(f"total_num_layers: {total_num_layers} for model: {self.model_name}") + sig = inspect.signature(mapping_fn) + if len(sig.parameters) >= 1 and "total_num_layers" in sig.parameters: + mapping = mapping_fn( + total_num_layers=total_num_layers, + ) + return mapping + + return mapping_fn() config = self.config mapping = self.convert_hf_map_to_sharding_map( diff --git a/src/MaxText/integration/tunix/weight_mapping/__init__.py b/src/MaxText/integration/tunix/weight_mapping/__init__.py index d250ee2fe..3a38388c5 100644 --- a/src/MaxText/integration/tunix/weight_mapping/__init__.py +++ b/src/MaxText/integration/tunix/weight_mapping/__init__.py @@ -19,6 +19,7 @@ model name. This allows for easy extension to support new models. """ +from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING @@ -31,6 +32,8 @@ def __getattr__(self, name): return LLAMA3_VLLM_MAPPING elif name.startswith("qwen3"): return QWEN3_VLLM_MAPPING + elif name.startswith("gpt-oss"): + return GptOssMaxTextMapping else: raise ValueError(f"{name} vLLM weight mapping not found.") diff --git a/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py b/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py new file mode 100644 index 000000000..bc6cbee8d --- /dev/null +++ b/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py @@ -0,0 +1,215 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format. +""" + +from dataclasses import dataclass +import logging +from typing import Dict, Optional, Tuple +import jax + + +@dataclass +class GptOssMaxTextMapping: + """ + Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX. + + Supports: + - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...) + """ + + @staticmethod + def lora_to_hf_mappings(): + """Provides the mapping for LoRA (Low-Rank Adaptation) weights. + + Returns: + None, as LoRA mappings are not defined for this model. + """ + return None + + @staticmethod + def to_hf_hook_fns(): + """Returns hook functions to fuse interleaved weights.""" + + def fuse_interleaved_gate(val, tgt_param): + """Fuse Gate (wi_0) with Multi-Host Sharding Support.""" + current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param + + # Safety Check + if current.shape[-1] != val.shape[-1] * 2: + if current.shape[-1] == val.shape[-1]: + logging.debug("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) + return val + logging.warning("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) + + # TODO: Enable multi-host sharding, if there is a mismatch in shapes. + # # MULTI-HOST case. + val = jax.device_put(val, current.sharding) + val.block_until_ready() + + logging.debug("Hook: Interleaving Gate -> Even columns") + return current.at[..., 0::2].set(val) + + def fuse_interleaved_up(val, tgt_param): + """Fuse Up (wi_1) with Multi-Host Sharding Support.""" + current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param + + if current.shape[-1] != val.shape[-1] * 2: + if current.shape[-1] == val.shape[-1]: + logging.debug("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) + return val + logging.warning("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) + + # TODO: Enable multi-host sharding, if there is a mismatch in shapes. + # # MULTI-HOST case. + val = jax.device_put(val, current.sharding) + val.block_until_ready() + + logging.debug("Hook: Interleaving Up -> Odd columns") + return current.at[..., 1::2].set(val) + + return { + r".*GptOssMlp\.wi_0.*": fuse_interleaved_gate, + r".*GptOssMlp\.wi_1.*": fuse_interleaved_up, + } + + @staticmethod + def to_hf_transpose_keys(): + """Returns keys that need to be transposed.""" + return {} + + @staticmethod + def to_hf_mapping( + layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo" + ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]: + """Returns the weight mapping for the model. + + Args: + layer_cycle_interval: The interval at which layers are cycled. + total_num_layers: The total number of layers in the model. + interleave_style: The style of interleaving used for the layers. + + Returns: + A dictionary mapping MaxText parameter names to vLLM parameter names. + """ + + mapping = {} + + # --- 1. Global Parameters --- + mapping.update( + { + "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)), + "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)), + "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")), + } + ) + + # --- 2. 
Layer Mapping Loop --- + layers_per_block = total_num_layers // layer_cycle_interval + + for block_idx in range(layer_cycle_interval): + src_block = f"base.decoder.layers.layers_{block_idx}" + if interleave_style == "modulo": + target_indices = range(block_idx, total_num_layers, layer_cycle_interval) + else: + start = block_idx * layers_per_block + target_indices = range(start, start + layers_per_block) + + regex_indices = "|".join(map(str, target_indices)) + layer_regex = f"layers\.({regex_indices})" + + # --- 3. Block Mappings (Standard) --- + mapping.update( + { + f"{src_block}.pre_self_attention_layer_norm.scale": ( + f"{layer_regex}.pre_attention_norm.scale", + (None, "layer"), + ), + f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")), + f"{src_block}.GptOssAttention.query.kernel": ( + f"{layer_regex}.attn.kernel_q_DNH", + (None, "layer", "model", None), + ), + f"{src_block}.GptOssAttention.key.kernel": ( + f"{layer_regex}.attn.kernel_k_DKH", + (None, "layer", "model", None), + ), + f"{src_block}.GptOssAttention.value.kernel": ( + f"{layer_regex}.attn.kernel_v_DKH", + (None, "layer", "model", None), + ), + f"{src_block}.GptOssAttention.out.kernel": ( + f"{layer_regex}.attn.kernel_o_proj_NHD", + ("model", "layer", None, None), + ), + f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)), + f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)), + f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)), + f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")), + f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")), + } + ) + + # MoE Router + mapping.update( + { + f"{src_block}.GptOssMlp.gate.kernel": ( + f"{layer_regex}.custom_module.router.kernel_DE", + (None, "layer", "model"), + ), + f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")), + } + ) + + # --- MOE EXPERTS --- + + # MLP1 BIASES + mapping.update( + { + f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")), + f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")), + } + ) + + # MLP1 WEIGHTS (Split -> Fused) + mapping.update( + { + f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)), + f"{src_block}.GptOssMlp.wi_1": ( + f"{layer_regex}.custom_module.mlp1_weight_EDF2", + # Original: (None, "layer", "expert", "model", None) + ("model", "layer", None), + ), + } + ) + + # MLP2 (Down Projection) + mapping.update( + { + f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")), + f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)), + } + ) + + # --- 4. Additional Config --- + mapping.update( + { + "additional_config": { + "layer_cycle_interval": layer_cycle_interval, + } + } + ) + + return mapping
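Note on the utils.py change above: `to_hf_mapping` implementations differ in arity across models, so the wrapper inspects the function signature and forwards `total_num_layers` only when the mapping accepts it. A minimal standalone sketch of that dispatch, assuming a hypothetical `toy_mapping` as a stand-in for a model's `to_hf_mapping` (not part of the patch):

# Minimal sketch of the signature-based dispatch used in utils.py.
# `toy_mapping` is a hypothetical stand-in for a model's to_hf_mapping.
import inspect


def toy_mapping(layer_cycle_interval: int = 2, total_num_layers: int = 36):
  return {"layer_cycle_interval": layer_cycle_interval, "total_num_layers": total_num_layers}


def call_mapping(mapping_fn, total_num_layers):
  # Forward total_num_layers only if the mapping function declares that parameter.
  if "total_num_layers" in inspect.signature(mapping_fn).parameters:
    return mapping_fn(total_num_layers=total_num_layers)
  return mapping_fn()


print(call_mapping(toy_mapping, total_num_layers=36))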
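Note on the gpt_oss.py mapping: it relies on two conventions that are easy to misread from the regexes alone, namely the modulo interleaving of scanned MaxText blocks onto unrolled vLLM layer indices, and the even/odd column fusion of wi_0/wi_1 into the vLLM-side mlp1 weight. The following standalone sketch illustrates both with NumPy and toy dimensions; the shapes are assumptions for illustration only, not the real model sizes:

# Standalone sketch (toy dimensions) of the two conventions used in gpt_oss.py.
import numpy as np

layer_cycle_interval = 2  # scanned blocks in MaxText
total_num_layers = 6      # unrolled layers on the vLLM side (toy value)

# Modulo interleaving: scanned block b maps to unrolled layers b, b + interval, ...
for block_idx in range(layer_cycle_interval):
  target_indices = list(range(block_idx, total_num_layers, layer_cycle_interval))
  layer_regex = r"layers\.(" + "|".join(map(str, target_indices)) + ")"
  print(f"block {block_idx} -> layers {target_indices}, regex: {layer_regex}")
# block 0 -> layers [0, 2, 4]; block 1 -> layers [1, 3, 5]

# Gate/up fusion: wi_0 fills the even columns and wi_1 the odd columns of the
# fused mlp1 weight, mirroring fuse_interleaved_gate / fuse_interleaved_up.
d_model, ffw = 4, 3
wi_0 = np.full((d_model, ffw), 1.0)   # "gate" half from MaxText
wi_1 = np.full((d_model, ffw), 2.0)   # "up" half from MaxText
fused = np.zeros((d_model, 2 * ffw))  # vLLM-side mlp1 weight (..., F2)
fused[..., 0::2] = wi_0               # what the wi_0 hook scatters into
fused[..., 1::2] = wi_1               # what the wi_1 hook scatters into
assert fused[0].tolist() == [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]

The real hooks perform the same even/odd scatter with `current.at[..., 0::2].set(val)` and `current.at[..., 1::2].set(val)` on JAX arrays, after `jax.device_put(val, current.sharding)` to place the source on the target sharding.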