Commit 413bfab

Add GPT OSS MaxText to vLLM mappings and helper functions.
1 parent 094b41d

3 files changed: +231 -1 lines changed

src/MaxText/integration/tunix/utils.py

Lines changed: 13 additions & 1 deletion

@@ -14,8 +14,10 @@
 
 """Utils for Tunix integration."""
 
+import inspect
 import re
 
+
 import MaxText.integration.tunix.weight_mapping as weight_mapping  # pylint: disable=consider-using-from-import
 from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
 from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS

@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
   def to_hf_mapping(self):
     """Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
     if self.use_standalone_mappings:
-      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
+      mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
+      total_num_layers = self.config["num_hidden_layers"]
+      print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
+      sig = inspect.signature(mapping_fn)
+      if len(sig.parameters) >= 1 and "total_num_layers" in sig.parameters:
+        mapping = mapping_fn(
+            total_num_layers=total_num_layers,
+        )
+        return mapping
+
+      return mapping_fn()
 
     config = self.config
     mapping = self.convert_hf_map_to_sharding_map(
src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -19,6 +19,7 @@
 model name. This allows for easy extension to support new models.
 """
 
+from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 

@@ -31,6 +32,8 @@ def __getattr__(self, name):
       return LLAMA3_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
+    elif name.startswith("gpt-oss"):
+      return GptOssMaxTextMapping
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
 
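
Note: model names are resolved by prefix through the module's __getattr__, so any name beginning with "gpt-oss" now returns GptOssMaxTextMapping. A standalone sketch of the same prefix dispatch; the _MappingRegistry class and its string return values are stand-ins for illustration, not the MaxText module itself:

class _MappingRegistry:
  """Hypothetical registry mirroring the prefix checks in __getattr__."""

  def __getattr__(self, name):
    if name.startswith("llama3"):
      return "LLAMA3_VLLM_MAPPING"
    elif name.startswith("qwen3"):
      return "QWEN3_VLLM_MAPPING"
    elif name.startswith("gpt-oss"):
      return "GptOssMaxTextMapping"
    else:
      raise ValueError(f"{name} vLLM weight mapping not found.")

# Hyphenated model names are not valid attribute syntax, so lookups go through getattr().
print(getattr(_MappingRegistry(), "gpt-oss-20b"))  # -> GptOssMaxTextMapping
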

src/MaxText/integration/tunix/weight_mapping/gpt_oss.py

Lines changed: 215 additions & 0 deletions

@@ -0,0 +1,215 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format."""
+
+from dataclasses import dataclass
+import logging
+from typing import Dict, Optional, Tuple
+import jax
+
+
+@dataclass
+class GptOssMaxTextMapping:
+  """
+  Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.
+
+  Supports:
+    - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
+  """
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_hook_fns():
+    """Returns hook functions to fuse interleaved weights."""
+
+    def fuse_interleaved_gate(val, tgt_param):
+      """Fuse Gate (wi_0) with Multi-Host Sharding Support."""
+      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param
+
+      # Safety check: the fused target should be twice as wide as the source.
+      if current.shape[-1] != val.shape[-1] * 2:
+        if current.shape[-1] == val.shape[-1]:
+          logging.debug("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+          return val
+        logging.warning("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+
+      # TODO: Enable multi-host sharding if there is a mismatch in shapes.
+      # MULTI-HOST case.
+      val = jax.device_put(val, current.sharding)
+      val.block_until_ready()
+
+      logging.debug("Hook: Interleaving Gate -> Even columns")
+      return current.at[..., 0::2].set(val)
+
+    def fuse_interleaved_up(val, tgt_param):
+      """Fuse Up (wi_1) with Multi-Host Sharding Support."""
+      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param
+
+      if current.shape[-1] != val.shape[-1] * 2:
+        if current.shape[-1] == val.shape[-1]:
+          logging.debug("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+          return val
+        logging.warning("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+
+      # TODO: Enable multi-host sharding if there is a mismatch in shapes.
+      # MULTI-HOST case.
+      val = jax.device_put(val, current.sharding)
+      val.block_until_ready()
+
+      logging.debug("Hook: Interleaving Up -> Odd columns")
+      return current.at[..., 1::2].set(val)
+
+    return {
+        r".*GptOssMlp\.wi_0.*": fuse_interleaved_gate,
+        r".*GptOssMlp\.wi_1.*": fuse_interleaved_up,
+    }
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    """Returns keys that need to be transposed."""
+    return {}
+
+  @staticmethod
+  def to_hf_mapping(
+      layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
+  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
+    """Returns the weight mapping for the model.
+
+    Args:
+      layer_cycle_interval: The interval at which layers are cycled.
+      total_num_layers: The total number of layers in the model.
+      interleave_style: The style of interleaving used for the layers.
+
+    Returns:
+      A dictionary mapping MaxText parameter names to vLLM parameter names.
+    """
+
+    mapping = {}
+
+    # --- 1. Global Parameters ---
+    mapping.update(
+        {
+            "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
+            "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
+            "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
+        }
+    )
+
+    # --- 2. Layer Mapping Loop ---
+    layers_per_block = total_num_layers // layer_cycle_interval
+
+    for block_idx in range(layer_cycle_interval):
+      src_block = f"base.decoder.layers.layers_{block_idx}"
+      if interleave_style == "modulo":
+        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
+      else:
+        start = block_idx * layers_per_block
+        target_indices = range(start, start + layers_per_block)
+
+      regex_indices = "|".join(map(str, target_indices))
+      layer_regex = f"layers\.({regex_indices})"
+
+      # --- 3. Block Mappings (Standard) ---
+      mapping.update(
+          {
+              f"{src_block}.pre_self_attention_layer_norm.scale": (
+                  f"{layer_regex}.pre_attention_norm.scale",
+                  (None, "layer"),
+              ),
+              f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
+              f"{src_block}.GptOssAttention.query.kernel": (
+                  f"{layer_regex}.attn.kernel_q_DNH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.key.kernel": (
+                  f"{layer_regex}.attn.kernel_k_DKH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.value.kernel": (
+                  f"{layer_regex}.attn.kernel_v_DKH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.out.kernel": (
+                  f"{layer_regex}.attn.kernel_o_proj_NHD",
+                  ("model", "layer", None, None),
+              ),
+              f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
+              f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
+          }
+      )
+
+      # MoE Router
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.gate.kernel": (
+                  f"{layer_regex}.custom_module.router.kernel_DE",
+                  (None, "layer", "model"),
+              ),
+              f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
+          }
+      )
+
+      # --- MOE EXPERTS ---
+
+      # MLP1 BIASES
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")),
+              f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")),
+          }
+      )
+
+      # MLP1 WEIGHTS (Split -> Fused)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)),
+              f"{src_block}.GptOssMlp.wi_1": (
+                  f"{layer_regex}.custom_module.mlp1_weight_EDF2",
+                  # Original: (None, "layer", "expert", "model", None)
+                  ("model", "layer", None),
+              ),
+          }
+      )
+
+      # MLP2 (Down Projection)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
+              f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
+          }
+      )
+
+    # --- 4. Additional Config ---
+    mapping.update(
+        {
+            "additional_config": {
+                "layer_cycle_interval": layer_cycle_interval,
+            }
+        }
+    )
+
+    return mapping
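
Note on the hook functions: to_hf_hook_fns writes the MaxText gate (wi_0) and up (wi_1) projections into one fused target tensor, gate into the even columns and up into the odd columns. A small numeric sketch of that interleaving with placeholder shapes (D=2, F=3), not the actual model tensors:

import jax.numpy as jnp

gate = jnp.full((2, 3), 1.0)  # stands in for wi_0, shape [D, F]
up = jnp.full((2, 3), 2.0)    # stands in for wi_1, shape [D, F]

fused = jnp.zeros((2, 6))              # fused target, shape [D, 2F]
fused = fused.at[..., 0::2].set(gate)  # gate -> even columns
fused = fused.at[..., 1::2].set(up)    # up   -> odd columns

print(fused[0])  # [1. 2. 1. 2. 1. 2.]
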
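
Note on the layer mapping: with the defaults layer_cycle_interval=2 and total_num_layers=36, the "modulo" style sends scanned block 0 to the even target layers and block 1 to the odd ones, and each block's regex matches all of its target indices. A quick sketch of the regex construction, using the same defaults:

layer_cycle_interval = 2
total_num_layers = 36

for block_idx in range(layer_cycle_interval):
  target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
  regex_indices = "|".join(map(str, target_indices))
  print(f"block {block_idx}: layers\\.({regex_indices})")

# block 0: layers\.(0|2|4|...|34)
# block 1: layers\.(1|3|5|...|35)
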
