from megatron.core import dist_checkpointing
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
+from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
@@ -348,15 +349,21 @@ def run_mcore_inference(
        active_hidden_size = model.decoder.layers[0].mlp.linear_fc1.input_size
    else:
        raise ValueError(f"Cannot infer hidden size from {type(model.decoder.layers[0])=}")
+
    inference_wrapper_config = InferenceWrapperConfig(
        hidden_size=active_hidden_size,
        inference_batch_times_seqlen_threshold=batch_size * model.max_sequence_length,
        fp32_residual_connection=False,
        params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
        padded_vocab_size=model.vocab_size,
    )
-    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
-    wrapped_model.prep_model_for_inference(prompt_tokens)
+    # Get full sequence output instead of only last token logits
+    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
+    inference_context.materialize_only_last_token_logits = False
+
+    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
+    wrapped_model.prep_model_for_inference()
+
    inference_input = wrapped_model.prep_inference_input(prompt_tokens)
    inference_input = wrapped_model.get_batch_for_context_window(
        inference_input, 0, model.max_sequence_length
@@ -375,7 +382,7 @@ def run_mcore_inference(
def run_mcore_inference_with_dummy_input(
    model: GPTModel | MambaModel, batch_size: int = 2, hidden_size: int | None = None
) -> torch.Tensor:
-    """Run inference on a wrapped Megatron GPT or Mamba model."""
+    """Run inference on a Megatron GPT or Mamba model with random dummy input."""
    prompt_tokens = torch.randint(
        0, model.vocab_size, (batch_size, model.max_sequence_length)
    ).cuda()
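For reference, a minimal sketch of how the new context-based setup in this diff would be driven end to end so that logits are materialized for every prompt position rather than only the last token. The helper name full_sequence_logits and the run_one_forward_step call are assumptions for illustration (run_one_forward_step is taken from Megatron-Core's inference wrapper API as I understand it) and are not part of this change:

# Hedged sketch, not part of this diff: exercise the wrapper with a
# StaticInferenceContext that keeps logits for the whole sequence.
from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)


def full_sequence_logits(model, prompt_tokens, inference_wrapper_config):
    # Build a static inference context from the wrapper config and disable the
    # "last token only" optimization so full-sequence logits are materialized.
    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
    inference_context.materialize_only_last_token_logits = False

    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
    wrapped_model.prep_model_for_inference()

    inference_input = wrapped_model.prep_inference_input(prompt_tokens)
    inference_input = wrapped_model.get_batch_for_context_window(
        inference_input, 0, model.max_sequence_length
    )
    # run_one_forward_step is an assumption from the Megatron-Core inference
    # wrapper API; with the flag above the result is expected to cover the
    # whole prompt, i.e. roughly [batch_size, seq_len, padded_vocab_size].
    logits = wrapped_model.run_one_forward_step(inference_input)
    return logits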