|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format. |
| 16 | +""" |
| 17 | + |
| 18 | +from dataclasses import dataclass |
| 19 | +import logging |
| 20 | +from typing import Dict, Optional, Tuple |
| 21 | +import jax |
| 22 | + |
| 23 | + |
| 24 | +@dataclass |
| 25 | +class GptOssMaxTextMapping: |
| 26 | + """ |
| 27 | + Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX. |
| 28 | +
|
| 29 | + Supports: |
| 30 | + - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...) |
| 31 | + """ |
| 32 | + |
| 33 | + @staticmethod |
| 34 | + def lora_to_hf_mappings(): |
| 35 | + """Provides the mapping for LoRA (Low-Rank Adaptation) weights. |
| 36 | +
|
| 37 | + Returns: |
| 38 | + None, as LoRA mappings are not defined for this model. |
| 39 | + """ |
| 40 | + return None |
| 41 | + |
| 42 | + @staticmethod |
| 43 | + def to_hf_hook_fns(): |
| 44 | + """Returns hook functions to fuse interleaved weights.""" |
| 45 | + |
| 46 | + def fuse_interleaved_gate(val, tgt_param): |
| 47 | + """Fuse Gate (wi_0) with Multi-Host Sharding Support.""" |
| 48 | + current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param |
| 49 | + |
| 50 | + # Safety Check |
| 51 | + if current.shape[-1] != val.shape[-1] * 2: |
| 52 | + if current.shape[-1] == val.shape[-1]: |
| 53 | + logging.debug("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) |
| 54 | + return val |
| 55 | + logging.warning("Gate Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) |
| 56 | + |
| 57 | + # TODO: Enable multi-host sharding, if there is a mismatch in shapes. |
| 58 | + # # MULTI-HOST case. |
| 59 | + val = jax.device_put(val, current.sharding) |
| 60 | + val.block_until_ready() |
| 61 | + |
| 62 | + logging.debug("Hook: Interleaving Gate -> Even columns") |
| 63 | + return current.at[..., 0::2].set(val) |
| 64 | + |
| 65 | + def fuse_interleaved_up(val, tgt_param): |
| 66 | + """Fuse Up (wi_1) with Multi-Host Sharding Support.""" |
| 67 | + current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param |
| 68 | + |
| 69 | + if current.shape[-1] != val.shape[-1] * 2: |
| 70 | + if current.shape[-1] == val.shape[-1]: |
| 71 | + logging.debug("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) |
| 72 | + return val |
| 73 | + logging.warning("Up Fusion Shape Warning: Src %s -> Tgt %s", val.shape, current.shape) |
| 74 | + |
| 75 | + # TODO: Enable multi-host sharding, if there is a mismatch in shapes. |
| 76 | + # # MULTI-HOST case. |
| 77 | + val = jax.device_put(val, current.sharding) |
| 78 | + val.block_until_ready() |
| 79 | + |
| 80 | + logging.debug("Hook: Interleaving Up -> Odd columns") |
| 81 | + return current.at[..., 1::2].set(val) |
| 82 | + |
| 83 | + return { |
| 84 | + r".*GptOssMlp\.wi_0.*": fuse_interleaved_gate, |
| 85 | + r".*GptOssMlp\.wi_1.*": fuse_interleaved_up, |
| 86 | + } |
| 87 | + |
| 88 | + @staticmethod |
| 89 | + def to_hf_transpose_keys(): |
| 90 | + """Returns keys that need to be transposed.""" |
| 91 | + return {} |
| 92 | + |
| 93 | + @staticmethod |
| 94 | + def to_hf_mapping( |
| 95 | + layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo" |
| 96 | + ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]: |
| 97 | + """Returns the weight mapping for the model. |
| 98 | +
|
| 99 | + Args: |
| 100 | + layer_cycle_interval: The interval at which layers are cycled. |
| 101 | + total_num_layers: The total number of layers in the model. |
| 102 | + interleave_style: The style of interleaving used for the layers. |
| 103 | +
|
| 104 | + Returns: |
| 105 | + A dictionary mapping MaxText parameter names to vLLM parameter names. |
| 106 | + """ |
| 107 | + |
| 108 | + mapping = {} |
| 109 | + |
| 110 | + # --- 1. Global Parameters --- |
| 111 | + mapping.update( |
| 112 | + { |
| 113 | + "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)), |
| 114 | + "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)), |
| 115 | + "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")), |
| 116 | + } |
| 117 | + ) |
| 118 | + |
| 119 | + # --- 2. Layer Mapping Loop --- |
| 120 | + layers_per_block = total_num_layers // layer_cycle_interval |
| 121 | + |
| 122 | + for block_idx in range(layer_cycle_interval): |
| 123 | + src_block = f"base.decoder.layers.layers_{block_idx}" |
| 124 | + if interleave_style == "modulo": |
| 125 | + target_indices = range(block_idx, total_num_layers, layer_cycle_interval) |
| 126 | + else: |
| 127 | + start = block_idx * layers_per_block |
| 128 | + target_indices = range(start, start + layers_per_block) |
| 129 | + |
| 130 | + regex_indices = "|".join(map(str, target_indices)) |
| 131 | + layer_regex = f"layers\.({regex_indices})" |
| 132 | + |
| 133 | + # --- 3. Block Mappings (Standard) --- |
| 134 | + mapping.update( |
| 135 | + { |
| 136 | + f"{src_block}.pre_self_attention_layer_norm.scale": ( |
| 137 | + f"{layer_regex}.pre_attention_norm.scale", |
| 138 | + (None, "layer"), |
| 139 | + ), |
| 140 | + f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")), |
| 141 | + f"{src_block}.GptOssAttention.query.kernel": ( |
| 142 | + f"{layer_regex}.attn.kernel_q_DNH", |
| 143 | + (None, "layer", "model", None), |
| 144 | + ), |
| 145 | + f"{src_block}.GptOssAttention.key.kernel": ( |
| 146 | + f"{layer_regex}.attn.kernel_k_DKH", |
| 147 | + (None, "layer", "model", None), |
| 148 | + ), |
| 149 | + f"{src_block}.GptOssAttention.value.kernel": ( |
| 150 | + f"{layer_regex}.attn.kernel_v_DKH", |
| 151 | + (None, "layer", "model", None), |
| 152 | + ), |
| 153 | + f"{src_block}.GptOssAttention.out.kernel": ( |
| 154 | + f"{layer_regex}.attn.kernel_o_proj_NHD", |
| 155 | + ("model", "layer", None, None), |
| 156 | + ), |
| 157 | + f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)), |
| 158 | + f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)), |
| 159 | + f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)), |
| 160 | + f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")), |
| 161 | + f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")), |
| 162 | + } |
| 163 | + ) |
| 164 | + |
| 165 | + # MoE Router |
| 166 | + mapping.update( |
| 167 | + { |
| 168 | + f"{src_block}.GptOssMlp.gate.kernel": ( |
| 169 | + f"{layer_regex}.custom_module.router.kernel_DE", |
| 170 | + (None, "layer", "model"), |
| 171 | + ), |
| 172 | + f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")), |
| 173 | + } |
| 174 | + ) |
| 175 | + |
| 176 | + # --- MOE EXPERTS --- |
| 177 | + |
| 178 | + # MLP1 BIASES |
| 179 | + mapping.update( |
| 180 | + { |
| 181 | + f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")), |
| 182 | + f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")), |
| 183 | + } |
| 184 | + ) |
| 185 | + |
| 186 | + # MLP1 WEIGHTS (Split -> Fused) |
| 187 | + mapping.update( |
| 188 | + { |
| 189 | + f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)), |
| 190 | + f"{src_block}.GptOssMlp.wi_1": ( |
| 191 | + f"{layer_regex}.custom_module.mlp1_weight_EDF2", |
| 192 | + # Original: (None, "layer", "expert", "model", None) |
| 193 | + ("model", "layer", None), |
| 194 | + ), |
| 195 | + } |
| 196 | + ) |
| 197 | + |
| 198 | + # MLP2 (Down Projection) |
| 199 | + mapping.update( |
| 200 | + { |
| 201 | + f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")), |
| 202 | + f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)), |
| 203 | + } |
| 204 | + ) |
| 205 | + |
| 206 | + # --- 4. Additional Config --- |
| 207 | + mapping.update( |
| 208 | + { |
| 209 | + "additional_config": { |
| 210 | + "layer_cycle_interval": layer_cycle_interval, |
| 211 | + } |
| 212 | + } |
| 213 | + ) |
| 214 | + |
| 215 | + return mapping |
0 commit comments