 
 from megatron.core import dist_checkpointing
 from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
@@ -348,15 +349,21 @@ def run_mcore_inference(
         active_hidden_size = model.decoder.layers[0].mlp.linear_fc1.input_size
     else:
         raise ValueError(f"Cannot infer hidden size from {type(model.decoder.layers[0])=}")
+
     inference_wrapper_config = InferenceWrapperConfig(
         hidden_size=active_hidden_size,
         inference_batch_times_seqlen_threshold=batch_size * model.max_sequence_length,
         fp32_residual_connection=False,
         params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
         padded_vocab_size=model.vocab_size,
     )
-    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
-    wrapped_model.prep_model_for_inference(prompt_tokens)
+    # Get full sequence output instead of only last token logits
+    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
+    inference_context.materialize_only_last_token_logits = False
+
+    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
+    wrapped_model.prep_model_for_inference()
+
     inference_input = wrapped_model.prep_inference_input(prompt_tokens)
     inference_input = wrapped_model.get_batch_for_context_window(
         inference_input, 0, model.max_sequence_length
@@ -375,7 +382,7 @@ def run_mcore_inference(
 def run_mcore_inference_with_dummy_input(
     model: GPTModel | MambaModel, batch_size: int = 2, hidden_size: int | None = None
 ) -> torch.Tensor:
-    """Run inference on a wrapped Megatron GPT or Mamba model."""
+    """Run inference on a Megatron GPT or Mamba model with random dummy input."""
     prompt_tokens = torch.randint(
         0, model.vocab_size, (batch_size, model.max_sequence_length)
     ).cuda()
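
For context, a minimal sketch of the wrapping path this diff introduces in run_mcore_inference, assuming `model`, `prompt_tokens`, and `inference_wrapper_config` are already built as in the surrounding file; `wrap_for_full_sequence_logits` is a hypothetical helper name, and the forward step itself (unchanged by the diff) is elided:

# Sketch only: mirrors the added lines above; `model`, `prompt_tokens`, and
# `inference_wrapper_config` are assumed to exist exactly as in run_mcore_inference.
from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)

def wrap_for_full_sequence_logits(model, inference_wrapper_config, prompt_tokens):
    # Per the diff's comment, the static context otherwise keeps only the last
    # token's logits; disabling this makes the wrapped forward pass materialize
    # logits for every position in the prompt.
    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
    inference_context.materialize_only_last_token_logits = False

    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
    wrapped_model.prep_model_for_inference()

    inference_input = wrapped_model.prep_inference_input(prompt_tokens)
    inference_input = wrapped_model.get_batch_for_context_window(
        inference_input, 0, model.max_sequence_length
    )
    return wrapped_model, inference_input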