Commit 1f0080a

Fix gpu test mcore inference utility for Mcore 0.14+
Signed-off-by: Keval Morabia <[email protected]>
Parent: 340eb7a

2 files changed: 10 additions, 9 deletions


tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 10 additions & 3 deletions
@@ -24,6 +24,7 @@

 from megatron.core import dist_checkpointing
 from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
@@ -348,15 +349,21 @@ def run_mcore_inference(
         active_hidden_size = model.decoder.layers[0].mlp.linear_fc1.input_size
     else:
         raise ValueError(f"Cannot infer hidden size from {type(model.decoder.layers[0])=}")
+
     inference_wrapper_config = InferenceWrapperConfig(
         hidden_size=active_hidden_size,
         inference_batch_times_seqlen_threshold=batch_size * model.max_sequence_length,
         fp32_residual_connection=False,
         params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
         padded_vocab_size=model.vocab_size,
     )
-    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
-    wrapped_model.prep_model_for_inference(prompt_tokens)
+    # Get full sequence output instead of only last token logits
+    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
+    inference_context.materialize_only_last_token_logits = False
+
+    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
+    wrapped_model.prep_model_for_inference()
+
     inference_input = wrapped_model.prep_inference_input(prompt_tokens)
     inference_input = wrapped_model.get_batch_for_context_window(
         inference_input, 0, model.max_sequence_length
@@ -375,7 +382,7 @@ def run_mcore_inference(
 def run_mcore_inference_with_dummy_input(
     model: GPTModel | MambaModel, batch_size: int = 2, hidden_size: int | None = None
 ) -> torch.Tensor:
-    """Run inference on a wrapped Megatron GPT or Mamba model."""
+    """Run inference on a Megatron GPT or Mamba model with random dummy input."""
     prompt_tokens = torch.randint(
         0, model.vocab_size, (batch_size, model.max_sequence_length)
     ).cuda()
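
For context, the change builds a StaticInferenceContext explicitly (instead of letting GPTInferenceWrapper create one internally) so the test helper gets logits for the full sequence rather than only the last token, and it matches the Mcore 0.14+ signatures where the context is passed to the wrapper and prep_model_for_inference() no longer takes the prompt tokens. The sketch below is a minimal, self-contained illustration of that flow, not the repository helper: the function name run_full_sequence_logits is made up, hidden_size is read from model.config instead of being inferred from the first decoder layer, and the final run_one_forward_step call plus the pipeline-stage broadcast comment reflect my reading of the wrapper API rather than code shown in this diff.

# Minimal sketch (not the repository helper): drive GPTInferenceWrapper with an
# explicit StaticInferenceContext so logits for the whole sequence are kept.
import torch

from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)


def run_full_sequence_logits(model, prompt_tokens: torch.Tensor) -> torch.Tensor:
    """Illustrative only: `model` is a built Megatron GPT model and
    `prompt_tokens` is a (batch, seq_len) CUDA tensor of token ids."""
    batch_size, seq_len = prompt_tokens.shape

    inference_wrapper_config = InferenceWrapperConfig(
        hidden_size=model.config.hidden_size,  # the test helper infers this from the first layer
        inference_batch_times_seqlen_threshold=batch_size * seq_len,
        fp32_residual_connection=False,
        params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
        padded_vocab_size=model.vocab_size,
    )

    # Mcore 0.14+: create the static context explicitly and disable the
    # default of materializing only the last token's logits.
    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
    inference_context.materialize_only_last_token_logits = False

    # The context is now passed as a third argument, and prep_model_for_inference()
    # is called without prompt tokens.
    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
    wrapped_model.prep_model_for_inference()

    inference_input = wrapped_model.prep_inference_input(prompt_tokens)
    inference_input = wrapped_model.get_batch_for_context_window(inference_input, 0, seq_len)

    # One forward step over the full context window (assumed wrapper method);
    # the real helper additionally broadcasts logits from the last pipeline stage.
    logits = wrapped_model.run_one_forward_step(inference_input)
    return logits

The in-diff comment makes the intent explicit: without setting materialize_only_last_token_logits = False, the wrapper would return only the last token's logits, which is not what this full-sequence test helper expects.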

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 0 additions & 6 deletions
@@ -33,11 +33,9 @@
     auto_quantize_helper,
     tensor_parallel_test_helper,
 )
-from packaging.version import Version

 skip_if_no_megatron()

-import megatron.core
 from megatron.core.parallel_state import (
     destroy_model_parallel,
     get_data_parallel_group,
@@ -256,10 +254,6 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device)
         mixed_block_size_config,
     ],
 )
-@pytest.mark.skipif(
-    Version(megatron.core.__version__) <= Version("0.13.0rc1"),
-    reason="This unittest need megatron.core>=0.13 with default heterogenous ckpt support.",
-)
 def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config):
     spawn_multiprocess_job(
         size=2,
