Skip to content

Commit aca8d7e

Browse files
committed
Adding MaxTextForCausalLM interface.
adding vllm.yml updating valid attention kernels removing protected access from output head call. reading paths with epath. fixing issues related to rngs and adding ep. remove duplicate logical axis rules. adding pyconfig deprecated validation. updating return type adapter. adding docstring to register.
1 parent 62b249a commit aca8d7e

File tree

10 files changed

+442
-19
lines changed

10 files changed

+442
-19
lines changed

src/MaxText/configs/vllm.yml

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
base_config: "base.yml"
16+
attention: "vllm_rpa"
17+
# NNX required for vLLM integration
18+
enable_nnx: True
19+
# Avoid re-initializing JAX distributed system when using vLLM
20+
skip_jax_distributed_system: True
21+
# Scanned layers are not supported with vLLM integration
22+
scan_layers: False
23+
24+
25+
# -------------- Logical Axis Rules --------------
26+
mesh_axes: ['data', 'model', 'expert']
27+
logical_axis_rules: [
28+
['activation_batch', ['data', 'expert']],
29+
['activation_batch_no_exp', ['data']],
30+
['activation_embed_and_logits_batch', ['data', 'expert']],
31+
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
32+
['activation_heads', ['model']],
33+
['activation_kv_heads', ['model']],
34+
['activation_length', ['expert']],
35+
['activation_q_length', ['expert']],
36+
['activation_embed', ['model']],
37+
['activation_mlp', ['model']],
38+
['activation_kv', ['model']],
39+
['activation_prefill_kv_batch', ['data', 'expert']],
40+
['activation_kv_batch', ['data', 'expert']],
41+
['activation_kv_batch_no_exp', ['data']],
42+
['activation_kv_head_dim', ['model']],
43+
['activation_vocab', ['model']],
44+
['activation_exp', ['expert']],
45+
['decode_batch', ['data', 'expert']],
46+
['mlp', ['model']],
47+
['mlp_no_fsdp', ['model']],
48+
['vocab', ['model']],
49+
['heads', ['model']],
50+
['q_heads', ['model']],
51+
['kv_heads', ['model']],
52+
['embed', ['expert']],
53+
['q_lora', ['expert']],
54+
['kv_lora', ['expert']],
55+
['norm', ['model']],
56+
['cache_heads', ['model']],
57+
['exp', ['expert']],
58+
['paged_kv_heads', ['model']],
59+
]
60+
data_sharding: [['data', 'model', 'expert']]
61+
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""MaxText vLLM adapter package."""
16+
17+
from tpu_inference.logger import init_logger
18+
from tpu_inference.models.common.model_loader import register_model
19+
from .adapter import MaxTextForCausalLM
20+
21+
22+
logger = init_logger(__name__)
23+
24+
25+
def register():
26+
"""Register MaxTextForCausalLM model with tpu_inference and vllm.
27+
28+
Note, this function is invoked directly by the vLLM engine during startup. As such,
29+
it leverages vLLM logging to report its status.
30+
"""
31+
logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.")
32+
register_model("MaxTextForCausalLM", MaxTextForCausalLM)
33+
logger.info("Successfully registered MaxTextForCausalLM model.")
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""vLLM adapter for MaxText models."""
16+
17+
import jax
18+
import jax.numpy as jnp
19+
20+
from etils import epath
21+
from flax import nnx
22+
from jax.sharding import Mesh
23+
from MaxText import model_creation_utils
24+
from MaxText import pyconfig
25+
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
26+
from MaxText.globals import MAXTEXT_PKG_DIR
27+
28+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
29+
from vllm.config import VllmConfig
30+
31+
32+
def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
33+
"""Generates a MaxText configuration from a vLLM configuration.
34+
35+
This function takes a vLLM configuration object and translates relevant
36+
parameters into a MaxText `HyperParameters` object. It handles loading
37+
paths and model names from the vLLM config, and applies a base MaxText
38+
vLLM configuration file.
39+
40+
Args:
41+
vllm_config: The vLLM configuration object containing model and load
42+
parameters.
43+
44+
Returns:
45+
A `pyconfig.HyperParameters` object configured for MaxText.
46+
47+
Raises:
48+
ValueError: If `hf_config_path` is not provided in the vLLM model config.
49+
"""
50+
51+
def _path_exists(path: str) -> bool:
52+
if not path:
53+
return False
54+
return epath.Path(path).exists()
55+
56+
if "maxtext_config" in vllm_config.additional_config:
57+
overrides = vllm_config.additional_config["maxtext_config"]
58+
else:
59+
overrides = {}
60+
load_path = None
61+
if _path_exists(vllm_config.load.download_dir):
62+
load_path = vllm_config.load.download_dir
63+
elif _path_exists(vllm_config.model.model):
64+
load_path = vllm_config.model.model
65+
66+
if load_path:
67+
overrides["load_parameters_path"] = load_path
68+
elif vllm_config.model.model:
69+
overrides["model_name"] = vllm_config.model.model
70+
71+
if vllm_config.model_config.hf_config_path is None:
72+
raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")
73+
74+
# Add base config path to positional args
75+
base_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
76+
argv_list = ["", str(base_config_path)]
77+
78+
maxtext_config = pyconfig.initialize(argv_list, **overrides)
79+
return maxtext_config
80+
81+
82+
class MaxTextDecoderModel(nnx.Module):
83+
"""A vLLM-compatible decoder model wrapper for MaxText.
84+
85+
This class adapts a MaxText model for use within the vLLM framework,
86+
handling configuration generation, model initialization, and execution
87+
of the decoding step.
88+
"""
89+
90+
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
91+
"""Initializes the MaxTextDecoderModel.
92+
93+
Args:
94+
vllm_config: The vLLM configuration object.
95+
rng_key: A JAX random key for model initialization.
96+
mesh: The JAX mesh device for model sharding.
97+
"""
98+
self.vllm_config = vllm_config
99+
self.maxtext_config = generate_maxtext_config(vllm_config)
100+
101+
# Model configuration
102+
self.mesh = mesh
103+
self.model_mode = MODEL_MODE_AUTOREGRESSIVE
104+
105+
# Model creation
106+
self.model: nnx.Module | None = None
107+
self.logits: jax.Array | None = None
108+
109+
def __call__(
110+
self,
111+
kv_caches: list[jax.Array],
112+
input_ids: jax.Array,
113+
attention_metadata: AttentionMetadata,
114+
*args,
115+
**kwargs,
116+
) -> tuple[list[jax.Array], jax.Array, list[jax.Array]]:
117+
"""Performs a forward pass through the decoder model.
118+
119+
Args:
120+
kv_caches: A list of JAX arrays representing the KV caches.
121+
input_ids: A JAX array of input token IDs.
122+
attention_metadata: Attention metadata for the decoding process.
123+
*args: Variable length argument list.
124+
**kwargs: Arbitrary keyword arguments.
125+
126+
Returns:
127+
A tuple containing:
128+
- updated_kv_caches: A list of updated KV caches.
129+
- hidden: The hidden states (Q, d_model).
130+
- aux_hidden_states: A list of auxiliary hidden states.
131+
132+
Raises:
133+
ValueError: If the model is not an instance of `nnx.Module`.
134+
"""
135+
if not isinstance(self.model, nnx.Module):
136+
raise ValueError("Model must be an instance of type nnx.Module.")
137+
138+
if input_ids.ndim < 2:
139+
input_ids = jnp.expand_dims(input_ids, axis=0)
140+
141+
input_positions = attention_metadata.input_positions
142+
if input_positions.ndim < 2:
143+
input_positions = jnp.expand_dims(input_positions, axis=0)
144+
145+
# Store any auxiliary hidden states that may be required by specific models
146+
aux_hidden_states = []
147+
logits, hidden, kv_caches = self.model(
148+
decoder_input_tokens=input_ids,
149+
decoder_positions=input_positions,
150+
kv_caches=kv_caches,
151+
attention_metadata=attention_metadata,
152+
model_mode=self.model_mode,
153+
**kwargs,
154+
)
155+
if hidden.ndim > 1:
156+
hidden = jnp.squeeze(hidden, axis=0)
157+
logits = jnp.squeeze(logits, axis=0)
158+
159+
self.logits = logits # cache logits for compute_logits call
160+
161+
return kv_caches, hidden, aux_hidden_states
162+
163+
def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
164+
"""Computes the logits from the hidden states.
165+
166+
Args:
167+
hidden_states: A JAX array of hidden states.
168+
169+
Returns:
170+
A JAX array of logits (Q, vocab_size).
171+
"""
172+
if self.logits is not None:
173+
return self.logits
174+
175+
embeddings = self.model.token_embedder
176+
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
177+
178+
def load_weights(self, rng_key: jax.Array) -> None:
179+
"""Loads model parameters on the provided mesh.
180+
181+
Args:
182+
rng_key: A JAX random key for model initialization.
183+
"""
184+
self.model, _ = model_creation_utils.create_nnx_model(
185+
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
186+
)
187+
188+
189+
class MaxTextForCausalLM(nnx.Module):
190+
"""A vLLM-compatible causal language model wrapper for MaxText.
191+
192+
This class serves as the primary interface for integrating MaxText models
193+
into the vLLM serving framework, specifically for causal language modeling
194+
tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected
195+
by vLLM.
196+
"""
197+
198+
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
199+
"""Initializes the MaxTextForCausalLM model.
200+
201+
Args:
202+
vllm_config: The vLLM configuration object.
203+
rng_key: A JAX random key for model initialization.
204+
mesh: The JAX mesh device for model sharding.
205+
"""
206+
self.cfg = vllm_config.model_config
207+
self.mesh = mesh
208+
self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh)
209+
self.is_text_generation_model = True
210+
211+
def __call__(
212+
self, kv_caches: list[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, *args, **kwargs
213+
) -> tuple[list[jax.Array], jax.Array]:
214+
"""Performs a forward pass through the causal language model.
215+
216+
Args:
217+
kv_caches: A list of JAX arrays representing the KV caches.
218+
input_ids: A JAX array of input token IDs.
219+
attention_metadata: Attention metadata for the decoding process.
220+
*args: Variable length argument list.
221+
**kwargs: Arbitrary keyword arguments.
222+
223+
Returns:
224+
A tuple containing:
225+
- updated_kv_caches: A list of updated KV caches.
226+
- hidden: The hidden states.
227+
- aux_hidden_states: A list of auxiliary hidden states.
228+
"""
229+
kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
230+
return kv_caches, hidden, aux_hidden_states
231+
232+
def forward(self, *args, **kwargs):
233+
"""Alias for __call__ for compatibility.
234+
235+
Args:
236+
*args: Variable length argument list.
237+
**kwargs: Arbitrary keyword arguments.
238+
239+
Returns:
240+
The result of the `__call__` method.
241+
"""
242+
return self(*args, **kwargs)
243+
244+
def get_input_embeddings(self) -> jax.Array:
245+
"""Returns the input embeddings of the model.
246+
247+
Returns:
248+
A JAX array representing the input embeddings.
249+
"""
250+
return self.model.model.token_embedder.embedding
251+
252+
def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
253+
"""Computes the logits from the hidden states using the underlying decoder model.
254+
255+
Args:
256+
hidden_states: A JAX array of hidden states.
257+
258+
Returns:
259+
A JAX array of logits.
260+
"""
261+
return self.model.compute_logits(hidden_states)
262+
263+
def load_weights(self, rng_key: jax.Array) -> None:
264+
"""Loads model weights using the underlying decoder model.
265+
266+
Args:
267+
rng_key: A JAX random key for model initialization.
268+
"""
269+
self.model.load_weights(rng_key)

0 commit comments

Comments
 (0)