Skip to content

Commit 868ab2c

Browse files
Merge pull request #2784 from AI-Hypercomputer:nicogrande/maxtext-vllm-sharding-fixes
PiperOrigin-RevId: 840889949
2 parents 2d5c026 + a32396d commit 868ab2c

File tree

3 files changed

+64
-28
lines changed

3 files changed

+64
-28
lines changed

src/MaxText/configs/vllm.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,24 @@ scan_layers: False
2525
# -------------- Logical Axis Rules --------------
2626
mesh_axes: ['data', 'model', 'expert']
2727
logical_axis_rules: [
28-
['activation_batch', ['data', 'expert']],
29-
['activation_batch_no_exp', ['data']],
30-
['activation_embed_and_logits_batch', ['data', 'expert']],
31-
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
28+
['activation_batch', ['expert']],
29+
['activation_batch_no_exp', []],
30+
['activation_embed_and_logits_batch', ['expert']],
31+
['activation_embed_and_logits_batch_sequence', ['expert']],
3232
['activation_heads', ['model']],
3333
['activation_kv_heads', ['model']],
34-
['activation_length', ['expert']],
35-
['activation_q_length', ['expert']],
34+
['activation_length', ['data', 'expert']],
35+
['activation_q_length', ['data', 'expert']],
3636
['activation_embed', ['model']],
3737
['activation_mlp', ['model']],
3838
['activation_kv', ['model']],
39-
['activation_prefill_kv_batch', ['data', 'expert']],
40-
['activation_kv_batch', ['data', 'expert']],
41-
['activation_kv_batch_no_exp', ['data']],
39+
['activation_prefill_kv_batch', ['expert']],
40+
['activation_kv_batch', ['expert']],
41+
['activation_kv_batch_no_exp', []],
4242
['activation_kv_head_dim', ['model']],
4343
['activation_vocab', ['model']],
4444
['activation_exp', ['expert']],
45-
['decode_batch', ['data', 'expert']],
45+
['decode_batch', ['expert']],
4646
['mlp', ['model']],
4747
['mlp_no_fsdp', ['model']],
4848
['vocab', ['model']],

src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020
from etils import epath
2121
from flax import nnx
22+
import flax.linen as nn
2223
from jax.sharding import Mesh
2324
from MaxText import model_creation_utils
25+
from MaxText import max_logging
2426
from MaxText import pyconfig
2527
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
2628
from MaxText.globals import MAXTEXT_PKG_DIR
@@ -106,6 +108,20 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> N
106108
self.model: nnx.Module | None = None
107109
self.logits: jax.Array | None = None
108110

111+
# Handle dummy weight loading during initialization
112+
if vllm_config.load_config.load_format == "dummy":
113+
if self.maxtext_config.load_parameters_path is not None:
114+
max_logging.log(
115+
"Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
116+
)
117+
self.maxtext_config.load_parameters_path = None
118+
119+
with self.mesh:
120+
self.load_weights(rng_key)
121+
122+
elif self.maxtext_config.load_parameters_path is None:
123+
max_logging.log("Warning: No load_parameters_path provided. The model will be initialized with random weights.")
124+
109125
def __call__(
110126
self,
111127
kv_caches: list[jax.Array],
@@ -142,16 +158,17 @@ def __call__(
142158
if input_positions.ndim < 2:
143159
input_positions = jnp.expand_dims(input_positions, axis=0)
144160

145-
# Store any auxiliary hidden states that may be required by specific models
146-
aux_hidden_states = []
147-
logits, hidden, kv_caches = self.model(
148-
decoder_input_tokens=input_ids,
149-
decoder_positions=input_positions,
150-
kv_caches=kv_caches,
151-
attention_metadata=attention_metadata,
152-
model_mode=self.model_mode,
153-
**kwargs,
154-
)
161+
with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
162+
aux_hidden_states = []
163+
logits, hidden, kv_caches = self.model(
164+
decoder_input_tokens=input_ids,
165+
decoder_positions=input_positions,
166+
kv_caches=kv_caches,
167+
attention_metadata=attention_metadata,
168+
model_mode=self.model_mode,
169+
**kwargs,
170+
)
171+
155172
if hidden.ndim > 1:
156173
hidden = jnp.squeeze(hidden, axis=0)
157174
logits = jnp.squeeze(logits, axis=0)
@@ -172,8 +189,9 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
172189
if self.logits is not None:
173190
return self.logits
174191

175-
embeddings = self.model.token_embedder
176-
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
192+
with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
193+
embeddings = self.model.token_embedder
194+
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
177195

178196
def load_weights(self, rng_key: jax.Array) -> None:
179197
"""Loads model parameters on the provided mesh.
@@ -226,7 +244,8 @@ def __call__(
226244
- hidden: The hidden states.
227245
- aux_hidden_states: A list of auxiliary hidden states.
228246
"""
229-
kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
247+
with self.mesh:
248+
kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
230249
return kv_caches, hidden, aux_hidden_states
231250

232251
def forward(self, *args, **kwargs):
@@ -247,7 +266,20 @@ def get_input_embeddings(self) -> jax.Array:
247266
Returns:
248267
A JAX array representing the input embeddings.
249268
"""
250-
return self.model.model.token_embedder.embedding
269+
with self.mesh:
270+
return self.model.model.token_embedder.embedding
271+
272+
def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
273+
"""Embeds the input token IDs using the model's token embedder.
274+
275+
Args:
276+
input_ids: A JAX array of input token IDs.
277+
278+
Returns:
279+
A JAX array of embedded input tokens.
280+
"""
281+
with self.mesh:
282+
return self.model.model.token_embedder(input_ids)
251283

252284
def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
253285
"""Computes the logits from the hidden states using the underlying decoder model.
@@ -258,12 +290,14 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
258290
Returns:
259291
A JAX array of logits.
260292
"""
261-
return self.model.compute_logits(hidden_states)
293+
with self.mesh:
294+
return self.model.compute_logits(hidden_states)
262295

263296
def load_weights(self, rng_key: jax.Array) -> None:
264297
"""Loads model weights using the underlying decoder model.
265298
266299
Args:
267300
rng_key: A JAX random key for model initialization.
268301
"""
269-
self.model.load_weights(rng_key)
302+
with self.mesh:
303+
self.model.load_weights(rng_key)

src/MaxText/model_creation_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN,
133133

134134
_create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key)
135135

136-
abstract_model = nnx.eval_shape(_create_model_partial)
136+
with nn.logical_axis_rules(config.logical_axis_rules):
137+
abstract_model = nnx.eval_shape(_create_model_partial)
137138
graphdef, abstract_state = nnx.split(abstract_model)
138139
specs = nnx.get_partition_spec(abstract_state)
139140

@@ -155,7 +156,8 @@ def create_sharded_state():
155156

156157
with mesh:
157158
# Create the model with sharded parameters.
158-
sharded_state = create_sharded_state()
159+
with nn.logical_axis_rules(config.logical_axis_rules):
160+
sharded_state = create_sharded_state()
159161
model = nnx.merge(graphdef, sharded_state)
160162

161163
if config.load_parameters_path:

0 commit comments

Comments (0)