Commit 69d6317

Add GPT OSS MaxText to vLLM mappings and helper functions.

1 parent 094b41d

3 files changed: 220 additions, 1 deletion

src/MaxText/integration/tunix/utils.py

Lines changed: 13 additions & 1 deletion
@@ -14,8 +14,10 @@
 
 """Utils for Tunix integration."""
 
+import inspect
 import re
 
+
 import MaxText.integration.tunix.weight_mapping as weight_mapping  # pylint: disable=consider-using-from-import
 from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
 from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
   def to_hf_mapping(self):
     """Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
     if self.use_standalone_mappings:
-      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
+      mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
+      total_num_layers = self.config["num_hidden_layers"]
+      print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
+      sig = inspect.signature(mapping_fn)
+      if len(sig.parameters) >= 1 and "total_num_layers" in sig.parameters:
+        mapping = mapping_fn(
+            total_num_layers=total_num_layers,
+        )
+        return mapping
+
+      return mapping_fn()
 
     config = self.config
     mapping = self.convert_hf_map_to_sharding_map(
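Note on the utils.py change above: `total_num_layers` is only forwarded when the standalone mapping's `to_hf_mapping` actually declares that parameter, so mappings that take no arguments keep working unchanged. A minimal, self-contained sketch of that `inspect.signature` check (the two stand-in mapping functions below are hypothetical, not MaxText code):

import inspect


def mapping_without_layers():
  """Hypothetical stand-in for a to_hf_mapping that takes no arguments (llama3/qwen3 style)."""
  return {"kind": "fixed"}


def mapping_with_layers(total_num_layers: int = 36):
  """Hypothetical stand-in for a to_hf_mapping parameterized by layer count (gpt-oss style)."""
  return {"kind": "layered", "total_num_layers": total_num_layers}


def call_mapping(mapping_fn, total_num_layers):
  """Forward total_num_layers only if the mapping function declares that parameter."""
  sig = inspect.signature(mapping_fn)
  if "total_num_layers" in sig.parameters:
    return mapping_fn(total_num_layers=total_num_layers)
  return mapping_fn()


print(call_mapping(mapping_without_layers, 36))  # {'kind': 'fixed'}
print(call_mapping(mapping_with_layers, 36))     # {'kind': 'layered', 'total_num_layers': 36}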

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
 model name. This allows for easy extension to support new models.
 """
 
+from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 
@@ -31,6 +32,8 @@ def __getattr__(self, name):
       return LLAMA3_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
+    elif name.startswith("gpt-oss"):
+      return GptOssMaxTextMapping
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
 
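With this change, any model name beginning with "gpt-oss" resolves to GptOssMaxTextMapping. A rough sketch of the prefix dispatch as a plain function (resolve_vllm_mapping and the returned strings are illustrative placeholders, not the module's real objects):

def resolve_vllm_mapping(name: str) -> str:
  """Illustrative prefix dispatch mirroring the __getattr__ branches in the diff."""
  if name.startswith("llama3"):
    return "LLAMA3_VLLM_MAPPING"
  if name.startswith("qwen3"):
    return "QWEN3_VLLM_MAPPING"
  if name.startswith("gpt-oss"):
    return "GptOssMaxTextMapping"
  raise ValueError(f"{name} vLLM weight mapping not found.")


print(resolve_vllm_mapping("gpt-oss-20b"))  # GptOssMaxTextMapping
print(resolve_vllm_mapping("qwen3-8b"))     # QWEN3_VLLM_MAPPING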

src/MaxText/integration/tunix/weight_mapping/gpt_oss.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format.
+"""
+
+from dataclasses import dataclass
+import logging
+from typing import Dict, Optional, Tuple
+import jax
+
+
+@dataclass
+class GptOssMaxTextMapping:
+  """
+  Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.
+
+  Supports:
+  - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
+  """
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_hook_fns():
+    """Returns hook functions to fuse interleaved weights."""
+    def fuse_interleaved_gate(val, tgt_param):
+      """Fuse Gate (wi_0) with Multi-Host Sharding Support."""
+      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param
+
+      # Safety Check
+      if current.shape[-1] != val.shape[-1] * 2:
+        if current.shape[-1] == val.shape[-1]:
+          logging.debug("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+          return val
+        logging.warning("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+
+      # TODO: Enable multi-host sharding, if there is a mismatch in shapes.
+      # # MULTI-HOST case.
+      val = jax.device_put(val, current.sharding)
+      val.block_until_ready()
+
+      logging.debug("Hook: Interleaving Gate -> Even columns")
+      return current.at[..., 0::2].set(val)
+
+    def fuse_interleaved_up(val, tgt_param):
+      """Fuse Up (wi_1) with Multi-Host Sharding Support."""
+      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param
+
+      if current.shape[-1] != val.shape[-1] * 2:
+        if current.shape[-1] == val.shape[-1]:
+          logging.debug("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+          return val
+        logging.warning("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape)
+
+      # TODO: Enable multi-host sharding, if there is a mismatch in shapes.
+      # # MULTI-HOST case.
+      val = jax.device_put(val, current.sharding)
+      val.block_until_ready()
+
+      logging.debug("Hook: Interleaving Up -> Odd columns")
+      return current.at[..., 1::2].set(val)
+
+    return {
+        r".*GptOssMlp\.wi_0.*": fuse_interleaved_gate,
+        r".*GptOssMlp\.wi_1.*": fuse_interleaved_up,
+    }
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    """Returns keys that need to be transposed."""
+    return {}
+
+  @staticmethod
+  def to_hf_mapping(
+      layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
+  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
+
+    mapping = {}
+
+    # --- 1. Global Parameters ---
+    mapping.update(
+        {
+            "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
+            "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
+            "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
+        }
+    )
+
+    # --- 2. Layer Mapping Loop ---
+    layers_per_block = total_num_layers // layer_cycle_interval
+
+    for block_idx in range(layer_cycle_interval):
+      src_block = f"base.decoder.layers.layers_{block_idx}"
+      if interleave_style == "modulo":
+        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
+      else:
+        start = block_idx * layers_per_block
+        target_indices = range(start, start + layers_per_block)
+
+      regex_indices = "|".join(map(str, target_indices))
+      layer_regex = f"layers\.({regex_indices})"
+
+      # --- 3. Block Mappings (Standard) ---
+      mapping.update(
+          {
+              f"{src_block}.pre_self_attention_layer_norm.scale": (
+                  f"{layer_regex}.pre_attention_norm.scale",
+                  (None, "layer"),
+              ),
+              f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
+              f"{src_block}.GptOssAttention.query.kernel": (
+                  f"{layer_regex}.attn.kernel_q_DNH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.key.kernel": (
+                  f"{layer_regex}.attn.kernel_k_DKH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.value.kernel": (
+                  f"{layer_regex}.attn.kernel_v_DKH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.out.kernel": (
+                  f"{layer_regex}.attn.kernel_o_proj_NHD",
+                  ("model", "layer", None, None),
+              ),
+              f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
+              f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
+          }
+      )
+
+      # MoE Router
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.gate.kernel": (
+                  f"{layer_regex}.custom_module.router.kernel_DE",
+                  (None, "layer", "model"),
+              ),
+              f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
+          }
+      )
+
+      # --- MOE EXPERTS ---
+
+      # MLP1 BIASES
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")),
+              f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")),
+          }
+      )
+
+      # MLP1 WEIGHTS (Split -> Fused)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)),
+              f"{src_block}.GptOssMlp.wi_1": (
+                  f"{layer_regex}.custom_module.mlp1_weight_EDF2",
+                  # Original: (None, "layer", "expert", "model", None)
+                  ("model", "layer", None),
+              ),
+          }
+      )
+
+      # MLP2 (Down Projection)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
+              f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
+          }
+      )
+
+    # --- 4. Additional Config ---
+    mapping.update(
+        {
+            "additional_config": {
+                "layer_cycle_interval": layer_cycle_interval,
+            }
+        }
+    )
+
+    return mapping
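To make the modulo interleaving concrete: with the defaults layer_cycle_interval=2 and total_num_layers=36, scanned block 0 maps to unscanned layers 0, 2, 4, ..., 34 and block 1 to layers 1, 3, 5, ..., 35. The standalone sketch below (not part of the commit) reproduces just the index and regex computation from to_hf_mapping:

layer_cycle_interval = 2
total_num_layers = 36
layers_per_block = total_num_layers // layer_cycle_interval  # 18, used only by the non-modulo style

for block_idx in range(layer_cycle_interval):
  # Modulo interleaving: block_idx, block_idx + interval, block_idx + 2 * interval, ...
  target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
  regex_indices = "|".join(map(str, target_indices))
  layer_regex = rf"layers\.({regex_indices})"
  print(block_idx, layer_regex)

# block 0 -> layers\.(0|2|4|...|34), i.e. the even layers
# block 1 -> layers\.(1|3|5|...|35), i.e. the odd layers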

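The hook functions registered in to_hf_hook_fns write the two split MaxText tensors into one fused vLLM tensor whose last dimension is twice as wide: wi_0 (gate) lands in the even columns and wi_1 (up) in the odd columns. A toy jax.numpy illustration of that write pattern (the shapes are invented for the example):

import jax.numpy as jnp

gate = jnp.array([[1.0, 2.0, 3.0]])    # stands in for a wi_0 slice, last dim F
up = jnp.array([[10.0, 20.0, 30.0]])   # stands in for a wi_1 slice, last dim F
fused = jnp.zeros((1, 6))              # stands in for the fused mlp1 weight, last dim 2 * F

fused = fused.at[..., 0::2].set(gate)  # gate -> even columns
fused = fused.at[..., 1::2].set(up)    # up   -> odd columns
print(fused)                           # [[ 1. 10.  2. 20.  3. 30.]]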