diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py
index 9c1dd1bf7..99d5715ee 100644
--- a/tests/_test_utils/torch_dist/plugins/megatron_common.py
+++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -24,6 +24,7 @@
 from megatron.core import dist_checkpointing
 from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
@@ -348,6 +349,7 @@ def run_mcore_inference(
         active_hidden_size = model.decoder.layers[0].mlp.linear_fc1.input_size
     else:
         raise ValueError(f"Cannot infer hidden size from {type(model.decoder.layers[0])=}")
+
     inference_wrapper_config = InferenceWrapperConfig(
         hidden_size=active_hidden_size,
         inference_batch_times_seqlen_threshold=batch_size * model.max_sequence_length,
@@ -355,8 +357,13 @@ def run_mcore_inference(
         params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
         padded_vocab_size=model.vocab_size,
     )
-    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
-    wrapped_model.prep_model_for_inference(prompt_tokens)
+    # Get full sequence output instead of only last token logits
+    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
+    inference_context.materialize_only_last_token_logits = False
+
+    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
+    wrapped_model.prep_model_for_inference()
+    inference_input = wrapped_model.prep_inference_input(prompt_tokens)
     inference_input = wrapped_model.get_batch_for_context_window(
         inference_input, 0, model.max_sequence_length
     )
@@ -375,7 +382,7 @@
 def run_mcore_inference_with_dummy_input(
     model: GPTModel | MambaModel, batch_size: int = 2, hidden_size: int | None = None
 ) -> torch.Tensor:
-    """Run inference on a wrapped Megatron GPT or Mamba model."""
+    """Run inference on a Megatron GPT or Mamba model with random dummy input."""
     prompt_tokens = torch.randint(
         0, model.vocab_size, (batch_size, model.max_sequence_length)
     ).cuda()
diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py
index c3630e028..226403ea2 100644
--- a/tests/gpu/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -33,11 +33,9 @@
     auto_quantize_helper,
     tensor_parallel_test_helper,
 )
-from packaging.version import Version
 
 skip_if_no_megatron()
 
-import megatron.core
 from megatron.core.parallel_state import (
     destroy_model_parallel,
     get_data_parallel_group,
@@ -256,10 +254,6 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device)
         mixed_block_size_config,
     ],
 )
-@pytest.mark.skipif(
-    Version(megatron.core.__version__) <= Version("0.13.0rc1"),
-    reason="This unittest need megatron.core>=0.13 with default heterogenous ckpt support.",
-)
 def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config):
     spawn_multiprocess_job(
         size=2,