Commit dbd887d

Add GPT OSS vllm mapping generator.
1 parent 9204d6b commit dbd887d

File tree

3 files changed: +221 -1 lines changed

src/MaxText/integration/tunix/utils.py

Lines changed: 13 additions & 1 deletion
@@ -14,8 +14,10 @@
 
 """Utils for Tunix integration."""
 
+import inspect
 import re
 
+
 import MaxText.integration.tunix.weight_mapping as weight_mapping  # pylint: disable=consider-using-from-import
 from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
 from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
   def to_hf_mapping(self):
     """Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
     if self.use_standalone_mappings:
-      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
+      mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
+      total_num_layers = self.config["num_hidden_layers"]
+      print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
+      sig = inspect.signature(mapping_fn)
+      if len(sig.parameters) >= 1 and "total_num_layers" in sig.parameters:
+        mapping = mapping_fn(
+            total_num_layers=total_num_layers,
+        )
+        return mapping
+
+      return mapping_fn()
 
     config = self.config
     mapping = self.convert_hf_map_to_sharding_map(
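For context on the dispatch above: inspect.signature lets the caller pass total_num_layers only to mapping callables that declare such a parameter, while older mappings keep their zero-argument call. A minimal, self-contained sketch of that pattern; the two mapping functions below are hypothetical stand-ins, not the real entries in STANDALONE_VLLM_WEIGHT_MAPPING:

import inspect

# Hypothetical standalone mapping callables: one takes total_num_layers, one does not.
def gpt_oss_mapping(total_num_layers: int = 36):
  return {f"layers_{i}": i for i in range(total_num_layers)}

def llama3_mapping():
  return {"fixed": "mapping"}

def call_mapping(mapping_fn, total_num_layers):
  # Pass total_num_layers only if the callable declares it in its signature.
  sig = inspect.signature(mapping_fn)
  if "total_num_layers" in sig.parameters:
    return mapping_fn(total_num_layers=total_num_layers)
  return mapping_fn()

print(len(call_mapping(gpt_oss_mapping, 24)))  # 24
print(call_mapping(llama3_mapping, 24))        # {'fixed': 'mapping'}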

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
 model name. This allows for easy extension to support new models.
 """
 
+from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 

@@ -31,6 +32,8 @@ def __getattr__(self, name):
       return LLAMA3_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
+    elif name.startswith("gpt"):
+      return GptOssMaxTextMapping
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
 
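The __getattr__ hook above resolves a mapping from a model-name prefix, so any name starting with "gpt" now resolves to GptOssMaxTextMapping. A toy sketch of that prefix dispatch, with string placeholders standing in for the real mapping objects and a made-up registry class:

class _MappingRegistry:
  """Toy registry mirroring the prefix dispatch above (names are illustrative)."""

  def __getattr__(self, name):
    if name.startswith("llama3"):
      return "LLAMA3_VLLM_MAPPING"
    elif name.startswith("qwen3"):
      return "QWEN3_VLLM_MAPPING"
    elif name.startswith("gpt"):
      return "GptOssMaxTextMapping"
    else:
      raise ValueError(f"{name} vLLM weight mapping not found.")

registry = _MappingRegistry()
print(getattr(registry, "gpt_oss_20b"))  # -> GptOssMaxTextMapping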

src/MaxText/integration/tunix/weight_mapping/gpt_oss.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format."""

from dataclasses import dataclass
import logging
from typing import Dict, Optional, Tuple


@dataclass
class GptOssMaxTextMapping:
  """
  Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.

  Supports:
  - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
  """

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_hook_fns():
    def fuse_interleaved_gate(val, tgt_param):
      """Fuse Gate (wi_0) with Multi-Host Sharding Support."""
      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param

      # Safety Check
      if current.shape[-1] != val.shape[-1] * 2:
        if current.shape[-1] == val.shape[-1]:
          logging.debug(f"Gate Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")
          return val
        logging.warning(f"Gate Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")

      # TODO: Enable multi-host sharding, if there is a mismatch in shapes.
      # # MULTI-HOST case.
      # val = jax.device_put(val, current.sharding)
      # val.block_until_ready()

      logging.debug("Hook: Interleaving Gate -> Even columns")
      return current.at[..., 0::2].set(val)

    def fuse_interleaved_up(val, tgt_param):
      """Fuse Up (wi_1) with Multi-Host Sharding Support."""
      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param

      if current.shape[-1] != val.shape[-1] * 2:
        if current.shape[-1] == val.shape[-1]:
          logging.debug(f"Up Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")
          return val
        logging.warning(f"Up Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")

      # TODO: Enable multi-host sharding, if there is a mismatch in shapes.
      # # MULTI-HOST case.
      # val = jax.device_put(val, current.sharding)
      # val.block_until_ready()

      logging.debug("Hook: Interleaving Up -> Odd columns")
      return current.at[..., 1::2].set(val)

    return {
        r'.*GptOssMlp\.wi_0.*': fuse_interleaved_gate,
        r'.*GptOssMlp\.wi_1.*': fuse_interleaved_up,
    }

  @staticmethod
  def to_hf_transpose_keys():
    return {}

  @staticmethod
  def to_hf_mapping(
      layer_cycle_interval: int = 2,
      total_num_layers: int = 36,
      interleave_style: str = "modulo"
  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:

    mapping = {}

    # --- 1. Global Parameters ---
    mapping.update({
        "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", (("data", "model"), None)),
        "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
        "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, ("data", "model"))),
    })

    # --- 2. Layer Mapping Loop ---
    layers_per_block = total_num_layers // layer_cycle_interval

    for block_idx in range(layer_cycle_interval):
      src_block = f"base.decoder.layers.layers_{block_idx}"
      if interleave_style == "modulo":
        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
      else:
        start = block_idx * layers_per_block
        target_indices = range(start, start + layers_per_block)

      regex_indices = "|".join(map(str, target_indices))
      layer_regex = f"layers\.({regex_indices})"

      # --- 3. Block Mappings (Standard) ---
      mapping.update({
          f"{src_block}.pre_self_attention_layer_norm.scale": (
              f"{layer_regex}.pre_attention_norm.scale", (None, "layer")
          ),
          f"{src_block}.post_self_attention_layer_norm.scale": (
              f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")
          ),
          f"{src_block}.GptOssAttention.query.kernel": (
              f"{layer_regex}.attn.kernel_q_DNH", (None, "layer", "model", None)
          ),
          f"{src_block}.GptOssAttention.key.kernel": (
              f"{layer_regex}.attn.kernel_k_DKH", (None, "layer", "model", None)
          ),
          f"{src_block}.GptOssAttention.value.kernel": (
              f"{layer_regex}.attn.kernel_v_DKH", (None, "layer", "model", None)
          ),
          f"{src_block}.GptOssAttention.out.kernel": (
              f"{layer_regex}.attn.kernel_o_proj_NHD", ("model", "layer", None, None)
          ),
          f"{src_block}.GptOssAttention.query.bias": (
              f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)
          ),
          f"{src_block}.GptOssAttention.key.bias": (
              f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)
          ),
          f"{src_block}.GptOssAttention.value.bias": (
              f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)
          ),
          f"{src_block}.GptOssAttention.out.bias": (
              f"{layer_regex}.attn.bias_o_D", (None, "layer")
          ),
          f"{src_block}.GptOssAttention.sinks": (
              f"{layer_regex}.attn.sinks_N", (None, "layer")
          ),
      })

      # MoE Router
      mapping.update({
          f"{src_block}.GptOssMlp.gate.kernel": (
              f"{layer_regex}.custom_module.router.kernel_DE", (None, "layer", "model")
          ),
          f"{src_block}.GptOssMlp.gate.bias": (
              f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")
          ),
      })

      # --- MOE EXPERTS ---

      # MLP1 BIASES
      mapping.update({
          f"{src_block}.GptOssMlp.wi_0_bias": (
              f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")
          ),
          f"{src_block}.GptOssMlp.wi_1_bias": (
              f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")
          ),
      })

      # MLP1 WEIGHTS (Split -> Fused)
      mapping.update({
          f"{src_block}.GptOssMlp.wi_0": (
              f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)
          ),
          f"{src_block}.GptOssMlp.wi_1": (
              f"{layer_regex}.custom_module.mlp1_weight_EDF2",
              # Original: (None, "layer", "expert", "model", None)
              ("model", "layer", None)
          ),
      })

      # MLP2 (Down Projection)
      mapping.update({
          f"{src_block}.GptOssMlp.wo_bias": (
              f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")
          ),
          f"{src_block}.GptOssMlp.wo": (
              f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)
          ),
      })

    return mapping
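Two details of the mapping above can be checked in isolation: the modulo interleave expands each scanned block into a regex over its target vLLM layer indices, and the wi_0/wi_1 hooks write the gate into even columns and the up-projection into odd columns of the fused MLP1 buffer. A small standalone sketch of both, using toy shapes and a reduced layer count rather than the GPT-OSS defaults:

import jax.numpy as jnp

# 1) Modulo interleaving: block_idx -> layers block_idx, block_idx + interval, ...
layer_cycle_interval = 2
total_num_layers = 12  # illustrative; the mapping above defaults to 36
for block_idx in range(layer_cycle_interval):
  target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
  layer_regex = "layers\\.(" + "|".join(map(str, target_indices)) + ")"
  print(block_idx, layer_regex)
# 0 layers\.(0|2|4|6|8|10)
# 1 layers\.(1|3|5|7|9|11)

# 2) Gate/up interleaving into a fused buffer twice as wide (toy 2x2 inputs).
gate = jnp.array([[1., 2.], [3., 4.]])  # stands in for wi_0
up = jnp.array([[5., 6.], [7., 8.]])    # stands in for wi_1
fused = jnp.zeros((2, 4))               # stands in for the fused MLP1 weight
fused = fused.at[..., 0::2].set(gate)   # gate -> even columns
fused = fused.at[..., 1::2].set(up)     # up   -> odd columns
print(fused)
# [[1. 5. 2. 6.]
#  [3. 7. 4. 8.]]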
