14 changes: 13 additions & 1 deletion src/MaxText/integration/tunix/utils.py
@@ -14,8 +14,10 @@

"""Utils for Tunix integration."""

import inspect
import re


import MaxText.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import
from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
def to_hf_mapping(self):
"""Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
if self.use_standalone_mappings:
return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
total_num_layers = self.config["num_hidden_layers"]
print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
sig = inspect.signature(mapping_fn)
if "total_num_layers" in sig.parameters:
mapping = mapping_fn(
total_num_layers=total_num_layers,
)
return mapping

return mapping_fn()

config = self.config
mapping = self.convert_hf_map_to_sharding_map(
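For context, a minimal, self-contained sketch of the dispatch pattern introduced above (the function names here are illustrative, not part of the MaxText API): the caller inspects the standalone mapping callable and forwards total_num_layers only when its signature declares that parameter.

import inspect

def call_mapping(mapping_fn, total_num_layers):
  # Forward total_num_layers only if mapping_fn declares it.
  sig = inspect.signature(mapping_fn)
  if "total_num_layers" in sig.parameters:
    return mapping_fn(total_num_layers=total_num_layers)
  return mapping_fn()

def legacy_mapping():  # mapping that needs no layer count
  return {"w": "weight"}

def layered_mapping(total_num_layers=36):  # GPT-OSS-style mapping
  return {"num_layers": total_num_layers}

assert call_mapping(legacy_mapping, 24) == {"w": "weight"}
assert call_mapping(layered_mapping, 24) == {"num_layers": 24}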
3 changes: 3 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -19,6 +19,7 @@
model name. This allows for easy extension to support new models.
"""

from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING

@@ -31,6 +32,8 @@ def __getattr__(self, name):
return LLAMA3_VLLM_MAPPING
elif name.startswith("qwen3"):
return QWEN3_VLLM_MAPPING
elif name.startswith("gpt-oss"):
return GptOssMaxTextMapping
else:
raise ValueError(f"{name} vLLM weight mapping not found.")

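A quick usage note on the registry above: lookups resolve a mapping by model-name prefix via __getattr__, and hyphenated names such as "gpt-oss-20b" cannot use attribute syntax, so they go through getattr(). A minimal, self-contained sketch of that pattern (the class and return values here are illustrative, not the actual MaxText registry):

class _PrefixRegistry:
  # Resolve a weight-mapping identifier from a model-name prefix.
  def __getattr__(self, name):
    if name.startswith("llama3"):
      return "LLAMA3_VLLM_MAPPING"
    if name.startswith("qwen3"):
      return "QWEN3_VLLM_MAPPING"
    if name.startswith("gpt-oss"):
      return "GptOssMaxTextMapping"
    raise ValueError(f"{name} vLLM weight mapping not found.")

registry = _PrefixRegistry()
assert getattr(registry, "gpt-oss-20b") == "GptOssMaxTextMapping"
assert getattr(registry, "qwen3-8b") == "QWEN3_VLLM_MAPPING"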
215 changes: 215 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
@@ -0,0 +1,215 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format.
"""

from dataclasses import dataclass
import logging
from typing import Dict, Optional, Tuple
import jax


@dataclass
class GptOssMaxTextMapping:
"""
Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.

Supports:
- Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
"""

@staticmethod
def lora_to_hf_mappings():
"""Provides the mapping for LoRA (Low-Rank Adaptation) weights.

Returns:
None, as LoRA mappings are not defined for this model.
"""
return None

@staticmethod
def to_hf_hook_fns():
"""Returns hook functions to fuse interleaved weights."""

def fuse_interleaved_gate(val, tgt_param):
"""Fuse Gate (wi_0) with Multi-Host Sharding Support."""
current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param

# Safety Check
if current.shape[-1] != val.shape[-1] * 2:
if current.shape[-1] == val.shape[-1]:
logging.debug("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
return val
logging.warning("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)

# TODO: Enable multi-host sharding if there is a mismatch in shapes.
# MULTI-HOST case.
val = jax.device_put(val, current.sharding)
val.block_until_ready()

logging.debug("Hook: Interleaving Gate -> Even columns")
return current.at[..., 0::2].set(val)

def fuse_interleaved_up(val, tgt_param):
"""Fuse Up (wi_1) with Multi-Host Sharding Support."""
current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param

if current.shape[-1] != val.shape[-1] * 2:
if current.shape[-1] == val.shape[-1]:
logging.debug("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
return val
logging.warning("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)

# TODO: Enable multi-host sharding if there is a mismatch in shapes.
# MULTI-HOST case.
val = jax.device_put(val, current.sharding)
val.block_until_ready()

logging.debug("Hook: Interleaving Up -> Odd columns")
return current.at[..., 1::2].set(val)

return {
r".*GptOssMlp\.wi_0.*": fuse_interleaved_gate,
r".*GptOssMlp\.wi_1.*": fuse_interleaved_up,
}
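
# Layout note: the two hooks above rebuild the fused [gate, up] buffer by writing
# wi_0 into the even columns and wi_1 into the odd columns of the target's last
# axis, e.g. a per-projection width of 3 fuses to [g0, u0, g1, u1, g2, u2].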

@staticmethod
def to_hf_transpose_keys():
"""Returns keys that need to be transposed."""
return {}

@staticmethod
def to_hf_mapping(
layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
"""Returns the weight mapping for the model.

Args:
layer_cycle_interval: The interval at which layers are cycled.
total_num_layers: The total number of layers in the model.
interleave_style: The style of interleaving used for the layers.

Returns:
A dictionary mapping MaxText parameter names to vLLM parameter names.
"""

mapping = {}

# --- 1. Global Parameters ---
mapping.update(
{
"base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
"base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
"base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
}
)

# --- 2. Layer Mapping Loop ---
layers_per_block = total_num_layers // layer_cycle_interval

for block_idx in range(layer_cycle_interval):
src_block = f"base.decoder.layers.layers_{block_idx}"
if interleave_style == "modulo":
target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
else:
start = block_idx * layers_per_block
target_indices = range(start, start + layers_per_block)

regex_indices = "|".join(map(str, target_indices))
layer_regex = f"layers\.({regex_indices})"

# --- 3. Block Mappings (Standard) ---
mapping.update(
{
f"{src_block}.pre_self_attention_layer_norm.scale": (
f"{layer_regex}.pre_attention_norm.scale",
(None, "layer"),
),
f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
f"{src_block}.GptOssAttention.query.kernel": (
f"{layer_regex}.attn.kernel_q_DNH",
(None, "layer", "model", None),
),
f"{src_block}.GptOssAttention.key.kernel": (
f"{layer_regex}.attn.kernel_k_DKH",
(None, "layer", "model", None),
),
f"{src_block}.GptOssAttention.value.kernel": (
f"{layer_regex}.attn.kernel_v_DKH",
(None, "layer", "model", None),
),
f"{src_block}.GptOssAttention.out.kernel": (
f"{layer_regex}.attn.kernel_o_proj_NHD",
("model", "layer", None, None),
),
f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
}
)

# MoE Router
mapping.update(
{
f"{src_block}.GptOssMlp.gate.kernel": (
f"{layer_regex}.custom_module.router.kernel_DE",
(None, "layer", "model"),
),
f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
}
)

# --- MOE EXPERTS ---

# MLP1 BIASES
mapping.update(
{
f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")),
f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")),
}
)

# MLP1 WEIGHTS (Split -> Fused)
mapping.update(
{
f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)),
f"{src_block}.GptOssMlp.wi_1": (
f"{layer_regex}.custom_module.mlp1_weight_EDF2",
# Original: (None, "layer", "expert", "model", None)
("model", "layer", None),
),
}
)

# MLP2 (Down Projection)
mapping.update(
{
f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
}
)

# --- 4. Additional Config ---
mapping.update(
{
"additional_config": {
"layer_cycle_interval": layer_cycle_interval,
}
}
)

return mapping
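
For reference, a minimal usage sketch of the mapping above with hypothetical small values (a 4-layer model split across 2 scanned blocks); it only exercises the function as defined in this file:

mapping = GptOssMaxTextMapping.to_hf_mapping(layer_cycle_interval=2, total_num_layers=4)

# With modulo interleaving, block 0 feeds layers 0 and 2 and block 1 feeds
# layers 1 and 3, so block 0's query kernel maps onto the regex target below.
print(mapping["base.decoder.layers.layers_0.GptOssAttention.query.kernel"])
# ('layers\\.(0|2).attn.kernel_q_DNH', (None, 'layer', 'model', None))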