Skip to content

Commit aca5b24

Browse files
Merge pull request #2612 from AI-Hypercomputer:nicogrande/maxtext-for-causal-lm
PiperOrigin-RevId: 833591289
2 parents 9ab4248 + aca8d7e commit aca5b24

File tree

10 files changed

+442
-19
lines changed

10 files changed

+442
-19
lines changed

src/MaxText/configs/vllm.yml

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
base_config: "base.yml"
16+
attention: "vllm_rpa"
17+
# NNX is required for the vLLM integration
18+
enable_nnx: True
19+
# Avoid re-initializing the JAX distributed system when using vLLM
20+
skip_jax_distributed_system: True
21+
# Scanned layers are not supported with vLLM integration
22+
scan_layers: False
23+
24+
25+
# -------------- Logical Axis Rules --------------
26+
mesh_axes: ['data', 'model', 'expert']
27+
logical_axis_rules: [
28+
['activation_batch', ['data', 'expert']],
29+
['activation_batch_no_exp', ['data']],
30+
['activation_embed_and_logits_batch', ['data', 'expert']],
31+
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
32+
['activation_heads', ['model']],
33+
['activation_kv_heads', ['model']],
34+
['activation_length', ['expert']],
35+
['activation_q_length', ['expert']],
36+
['activation_embed', ['model']],
37+
['activation_mlp', ['model']],
38+
['activation_kv', ['model']],
39+
['activation_prefill_kv_batch', ['data', 'expert']],
40+
['activation_kv_batch', ['data', 'expert']],
41+
['activation_kv_batch_no_exp', ['data']],
42+
['activation_kv_head_dim', ['model']],
43+
['activation_vocab', ['model']],
44+
['activation_exp', ['expert']],
45+
['decode_batch', ['data', 'expert']],
46+
['mlp', ['model']],
47+
['mlp_no_fsdp', ['model']],
48+
['vocab', ['model']],
49+
['heads', ['model']],
50+
['q_heads', ['model']],
51+
['kv_heads', ['model']],
52+
['embed', ['expert']],
53+
['q_lora', ['expert']],
54+
['kv_lora', ['expert']],
55+
['norm', ['model']],
56+
['cache_heads', ['model']],
57+
['exp', ['expert']],
58+
['paged_kv_heads', ['model']],
59+
]
60+
data_sharding: [['data', 'model', 'expert']]
61+
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""MaxText vLLM adapter package."""
16+
17+
from tpu_inference.logger import init_logger
18+
from tpu_inference.models.common.model_loader import register_model
19+
from .adapter import MaxTextForCausalLM
20+
21+
22+
# Module-level logger created through tpu_inference's `init_logger` helper.
logger = init_logger(__name__)
23+
24+
25+
def register():
  """Make the MaxTextForCausalLM model available to tpu_inference and vLLM.

  The vLLM engine calls this function itself while it starts up, which is why
  progress is reported through the vLLM logger.
  """
  logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.")
  architecture = "MaxTextForCausalLM"
  register_model(architecture, MaxTextForCausalLM)
  logger.info("Successfully registered MaxTextForCausalLM model.")
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""vLLM adapter for MaxText models."""
16+
17+
import jax
18+
import jax.numpy as jnp
19+
20+
from etils import epath
21+
from flax import nnx
22+
from jax.sharding import Mesh
23+
from MaxText import model_creation_utils
24+
from MaxText import pyconfig
25+
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
26+
from MaxText.globals import MAXTEXT_PKG_DIR
27+
28+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
29+
from vllm.config import VllmConfig
30+
31+
32+
def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
  """Builds a MaxText configuration from a vLLM configuration.

  The MaxText overrides come from one of two places:

  * If `vllm_config.additional_config` has a "maxtext_config" entry, that
    mapping is used as-is.
  * Otherwise they are derived from the vLLM settings: the first existing path
    among the load config's `download_dir` and the model config's `model`
    becomes `load_parameters_path`; if neither exists on disk and `model` is
    non-empty, it is passed through as `model_name`.

  The overrides are applied on top of the packaged `configs/vllm.yml` file.

  Args:
    vllm_config: The vLLM configuration object containing model and load
      parameters.

  Returns:
    A `pyconfig.HyperParameters` object configured for MaxText.

  Raises:
    ValueError: If `hf_config_path` is not provided in the vLLM model config.
  """

  def _path_exists(path: str) -> bool:
    """Returns True only for a non-empty path that exists."""
    return bool(path) and epath.Path(path).exists()

  model_config = vllm_config.model_config
  # Validate before touching the filesystem so a misconfiguration fails fast.
  if model_config.hf_config_path is None:
    raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")

  additional_config = vllm_config.additional_config
  if "maxtext_config" in additional_config:
    overrides = additional_config["maxtext_config"]
  else:
    # NOTE(review): derived overrides are only computed when no explicit
    # "maxtext_config" was supplied -- confirm this is the intended precedence.
    overrides = {}
    # VllmConfig exposes these sub-configs as `load_config` / `model_config`.
    download_dir = vllm_config.load_config.download_dir
    model_ref = model_config.model
    if _path_exists(download_dir):
      overrides["load_parameters_path"] = download_dir
    elif _path_exists(model_ref):
      overrides["load_parameters_path"] = model_ref
    elif model_ref:
      # Not a path that exists: hand the value to MaxText as a model name.
      overrides["model_name"] = model_ref

  # pyconfig.initialize takes an argv-style list: a placeholder program name
  # followed by the base config path.
  base_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
  argv_list = ["", str(base_config_path)]

  return pyconfig.initialize(argv_list, **overrides)
80+
81+
82+
class MaxTextDecoderModel(nnx.Module):
  """A vLLM-compatible decoder model wrapper for MaxText.

  The wrapped MaxText model is not created in `__init__`; it is built by
  `load_weights`. Both `__call__` and the fallback path of `compute_logits`
  raise `ValueError` until that has happened.
  """

  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
    """Initializes the MaxTextDecoderModel.

    Args:
      vllm_config: The vLLM configuration object.
      rng_key: A JAX random key. It is not used here; model creation uses the
        key passed to `load_weights`.
      mesh: The JAX mesh device for model sharding.
    """
    self.vllm_config = vllm_config
    self.maxtext_config = generate_maxtext_config(vllm_config)

    # Model configuration
    self.mesh = mesh
    self.model_mode = MODEL_MODE_AUTOREGRESSIVE

    # Filled in later: `model` by load_weights(), `logits` by __call__().
    self.model: nnx.Module | None = None
    self.logits: jax.Array | None = None

  def _require_model(self) -> nnx.Module:
    """Returns the wrapped model, raising ValueError if it has not been created."""
    if not isinstance(self.model, nnx.Module):
      raise ValueError("Model must be an instance of type nnx.Module.")
    return self.model

  def __call__(
      self,
      kv_caches: list[jax.Array],
      input_ids: jax.Array,
      attention_metadata: AttentionMetadata,
      *args,
      **kwargs,
  ) -> tuple[list[jax.Array], jax.Array, list[jax.Array]]:
    """Performs a forward pass through the decoder model.

    Args:
      kv_caches: A list of JAX arrays representing the KV caches.
      input_ids: A JAX array of input token IDs.
      attention_metadata: Attention metadata for the decoding process; its
        `input_positions` are used as the decoder positions.
      *args: Accepted for interface compatibility; not forwarded to the model.
      **kwargs: Extra keyword arguments forwarded to the wrapped model.

    Returns:
      A tuple containing:
        - updated_kv_caches: A list of updated KV caches.
        - hidden: The hidden states (Q, d_model).
        - aux_hidden_states: A list of auxiliary hidden states (always empty
          here).

    Raises:
      ValueError: If the model is not an instance of `nnx.Module`.
    """
    model = self._require_model()

    # Give inputs of rank < 2 a leading axis of size 1 before calling the model.
    if input_ids.ndim < 2:
      input_ids = jnp.expand_dims(input_ids, axis=0)

    input_positions = attention_metadata.input_positions
    if input_positions.ndim < 2:
      input_positions = jnp.expand_dims(input_positions, axis=0)

    # Store any auxiliary hidden states that may be required by specific models
    aux_hidden_states = []
    logits, hidden, kv_caches = model(
        decoder_input_tokens=input_ids,
        decoder_positions=input_positions,
        kv_caches=kv_caches,
        attention_metadata=attention_metadata,
        model_mode=self.model_mode,
        **kwargs,
    )
    # Drop the leading axis again so callers receive un-batched outputs.
    # NOTE(review): logits are squeezed together with hidden -- confirm both
    # always carry the same leading axis.
    if hidden.ndim > 1:
      hidden = jnp.squeeze(hidden, axis=0)
      logits = jnp.squeeze(logits, axis=0)

    self.logits = logits  # cache logits for compute_logits call

    return kv_caches, hidden, aux_hidden_states

  def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
    """Computes the logits from the hidden states.

    If `__call__` has already cached logits, those are returned directly and
    `hidden_states` is not consulted. Otherwise the output head of the wrapped
    model is applied to `hidden_states`.

    Args:
      hidden_states: A JAX array of hidden states.

    Returns:
      A JAX array of logits (Q, vocab_size).

    Raises:
      ValueError: If no logits are cached and the model is not an instance of
        `nnx.Module` (for example because `load_weights` was never called).
    """
    # NOTE(review): the cached value ignores `hidden_states`; verify callers
    # never pass a different selection of rows than the last forward pass.
    if self.logits is not None:
      return self.logits

    # Same guard as __call__, so a missing model gives a clear ValueError
    # instead of an AttributeError on None.
    model = self._require_model()
    embeddings = model.token_embedder
    return model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)

  def load_weights(self, rng_key: jax.Array) -> None:
    """Creates the wrapped MaxText model on the stored mesh.

    Args:
      rng_key: A JAX random key for model initialization.
    """
    # create_nnx_model returns a pair; only the model (first element) is kept.
    self.model, _ = model_creation_utils.create_nnx_model(
        self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
    )
187+
188+
189+
class MaxTextForCausalLM(nnx.Module):
  """A vLLM-compatible causal language model wrapper for MaxText.

  This class serves as the primary interface for integrating MaxText models
  into the vLLM serving framework, specifically for causal language modeling
  tasks. It wraps a `MaxTextDecoderModel` (stored as `self.model`) and
  delegates the forward pass, logits computation and weight loading to it.
  """

  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
    """Initializes the MaxTextForCausalLM model.

    Args:
      vllm_config: The vLLM configuration object.
      rng_key: A JAX random key for model initialization.
      mesh: The JAX mesh device for model sharding.
    """
    self.cfg = vllm_config.model_config
    self.mesh = mesh
    self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh)
    self.is_text_generation_model = True

  def __call__(
      self,
      kv_caches: list[jax.Array],
      input_ids: jax.Array,
      attention_metadata: AttentionMetadata,
      *args,
      **kwargs,
  ) -> tuple[list[jax.Array], jax.Array, list[jax.Array]]:
    """Performs a forward pass through the causal language model.

    Args:
      kv_caches: A list of JAX arrays representing the KV caches.
      input_ids: A JAX array of input token IDs.
      attention_metadata: Attention metadata for the decoding process.
      *args: Variable length argument list, passed to the decoder model.
      **kwargs: Arbitrary keyword arguments, passed to the decoder model.

    Returns:
      A tuple containing:
        - updated_kv_caches: A list of updated KV caches.
        - hidden: The hidden states.
        - aux_hidden_states: A list of auxiliary hidden states.
    """
    kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
    return kv_caches, hidden, aux_hidden_states

  def forward(self, *args, **kwargs):
    """Alias for __call__ for compatibility.

    Args:
      *args: Variable length argument list.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The result of the `__call__` method.
    """
    return self(*args, **kwargs)

  def get_input_embeddings(self) -> jax.Array:
    """Returns the input embeddings of the model.

    Returns:
      The `embedding` attribute of the wrapped MaxText model's token embedder.
    """
    # self.model is the decoder wrapper; self.model.model is the MaxText model.
    return self.model.model.token_embedder.embedding

  def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
    """Computes the logits from the hidden states using the underlying decoder model.

    Args:
      hidden_states: A JAX array of hidden states.

    Returns:
      A JAX array of logits.
    """
    return self.model.compute_logits(hidden_states)

  def load_weights(self, rng_key: jax.Array) -> None:
    """Loads model weights using the underlying decoder model.

    Args:
      rng_key: A JAX random key for model initialization.
    """
    self.model.load_weights(rng_key)

0 commit comments

Comments
 (0)