1919
2020from etils import epath
2121from flax import nnx
22+ import flax .linen as nn
2223from jax .sharding import Mesh
2324from MaxText import model_creation_utils
25+ from MaxText import max_logging
2426from MaxText import pyconfig
2527from MaxText .common_types import MODEL_MODE_AUTOREGRESSIVE
2628from MaxText .globals import MAXTEXT_PKG_DIR
@@ -106,6 +108,20 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> N
106108 self .model : nnx .Module | None = None
107109 self .logits : jax .Array | None = None
108110
111+ # Handle dummy weight loading during initialization
112+ if vllm_config .load_config .load_format == "dummy" :
113+ if self .maxtext_config .load_parameters_path is not None :
114+ max_logging .log (
115+ "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
116+ )
117+ self .maxtext_config .load_parameters_path = None
118+
119+ with self .mesh :
120+ self .load_weights (rng_key )
121+
122+ elif self .maxtext_config .load_parameters_path is None :
123+ max_logging .log ("Warning: No load_parameters_path provided. The model will be initialized with random weights." )
124+
109125 def __call__ (
110126 self ,
111127 kv_caches : list [jax .Array ],
@@ -142,16 +158,17 @@ def __call__(
142158 if input_positions .ndim < 2 :
143159 input_positions = jnp .expand_dims (input_positions , axis = 0 )
144160
145- # Store any auxiliary hidden states that may be required by specific models
146- aux_hidden_states = []
147- logits , hidden , kv_caches = self .model (
148- decoder_input_tokens = input_ids ,
149- decoder_positions = input_positions ,
150- kv_caches = kv_caches ,
151- attention_metadata = attention_metadata ,
152- model_mode = self .model_mode ,
153- ** kwargs ,
154- )
161+ with nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
162+ aux_hidden_states = []
163+ logits , hidden , kv_caches = self .model (
164+ decoder_input_tokens = input_ids ,
165+ decoder_positions = input_positions ,
166+ kv_caches = kv_caches ,
167+ attention_metadata = attention_metadata ,
168+ model_mode = self .model_mode ,
169+ ** kwargs ,
170+ )
171+
155172 if hidden .ndim > 1 :
156173 hidden = jnp .squeeze (hidden , axis = 0 )
157174 logits = jnp .squeeze (logits , axis = 0 )
@@ -172,8 +189,9 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
172189 if self .logits is not None :
173190 return self .logits
174191
175- embeddings = self .model .token_embedder
176- return self .model .decoder .apply_output_head (embeddings , hidden_states , True , self .model_mode )
192+ with nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
193+ embeddings = self .model .token_embedder
194+ return self .model .decoder .apply_output_head (embeddings , hidden_states , True , self .model_mode )
177195
178196 def load_weights (self , rng_key : jax .Array ) -> None :
179197 """Loads model parameters on the provided mesh.
@@ -226,7 +244,8 @@ def __call__(
226244 - hidden: The hidden states.
227245 - aux_hidden_states: A list of auxiliary hidden states.
228246 """
229- kv_caches , hidden , aux_hidden_states = self .model (kv_caches , input_ids , attention_metadata , * args , ** kwargs )
247+ with self .mesh :
248+ kv_caches , hidden , aux_hidden_states = self .model (kv_caches , input_ids , attention_metadata , * args , ** kwargs )
230249 return kv_caches , hidden , aux_hidden_states
231250
232251 def forward (self , * args , ** kwargs ):
@@ -247,7 +266,20 @@ def get_input_embeddings(self) -> jax.Array:
247266 Returns:
248267 A JAX array representing the input embeddings.
249268 """
250- return self .model .model .token_embedder .embedding
269+ with self .mesh :
270+ return self .model .model .token_embedder .embedding
271+
272+ def embed_input_ids (self , input_ids : jax .Array ) -> jax .Array :
273+ """Embeds the input token IDs using the model's token embedder.
274+
275+ Args:
276+ input_ids: A JAX array of input token IDs.
277+
278+ Returns:
279+ A JAX array of embedded input tokens.
280+ """
281+ with self .mesh :
282+ return self .model .model .token_embedder (input_ids )
251283
252284 def compute_logits (self , hidden_states : jax .Array ) -> jax .Array :
253285 """Computes the logits from the hidden states using the underlying decoder model.
@@ -258,12 +290,14 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
258290 Returns:
259291 A JAX array of logits.
260292 """
261- return self .model .compute_logits (hidden_states )
293+ with self .mesh :
294+ return self .model .compute_logits (hidden_states )
262295
263296 def load_weights (self , rng_key : jax .Array ) -> None :
264297 """Loads model weights using the underlying decoder model.
265298
266299 Args:
267300 rng_key: A JAX random key for model initialization.
268301 """
269- self .model .load_weights (rng_key )
302+ with self .mesh :
303+ self .model .load_weights (rng_key )
0 commit comments