13 changes: 10 additions & 3 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -24,6 +24,7 @@

from megatron.core import dist_checkpointing
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
)
@@ -348,15 +349,21 @@ def run_mcore_inference(
active_hidden_size = model.decoder.layers[0].mlp.linear_fc1.input_size
else:
raise ValueError(f"Cannot infer hidden size from {type(model.decoder.layers[0])=}")

inference_wrapper_config = InferenceWrapperConfig(
hidden_size=active_hidden_size,
inference_batch_times_seqlen_threshold=batch_size * model.max_sequence_length,
fp32_residual_connection=False,
params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
padded_vocab_size=model.vocab_size,
)
wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
wrapped_model.prep_model_for_inference(prompt_tokens)
# Get full sequence output instead of only last token logits
inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
inference_context.materialize_only_last_token_logits = False

wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
wrapped_model.prep_model_for_inference()

inference_input = wrapped_model.prep_inference_input(prompt_tokens)
inference_input = wrapped_model.get_batch_for_context_window(
inference_input, 0, model.max_sequence_length
@@ -375,7 +382,7 @@ def run_mcore_inference(
def run_mcore_inference_with_dummy_input(
model: GPTModel | MambaModel, batch_size: int = 2, hidden_size: int | None = None
) -> torch.Tensor:
"""Run inference on a wrapped Megatron GPT or Mamba model."""
"""Run inference on a Megatron GPT or Mamba model with random dummy input."""
prompt_tokens = torch.randint(
0, model.vocab_size, (batch_size, model.max_sequence_length)
).cuda()
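For context, below is a condensed sketch of the updated call flow in run_mcore_inference after this change. It is not the full test helper: it assumes the names already imported in megatron_common.py (torch, InferenceWrapperConfig, GPTInferenceWrapper, StaticInferenceContext) plus an already constructed model and prompt_tokens batch, and the forward step that actually produces the logits is outside the visible hunk, so it is only indicated by a comment.

def run_full_sequence_inference(model, prompt_tokens, active_hidden_size, batch_size):
    """Hypothetical condensed helper mirroring the updated run_mcore_inference flow."""
    inference_wrapper_config = InferenceWrapperConfig(
        hidden_size=active_hidden_size,
        inference_batch_times_seqlen_threshold=batch_size * model.max_sequence_length,
        fp32_residual_connection=False,
        params_dtype=torch.bfloat16 if model.config.bf16 else torch.float32,
        padded_vocab_size=model.vocab_size,
    )

    # New: an explicit static inference context, configured to keep logits for
    # every position instead of only the last token.
    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
    inference_context.materialize_only_last_token_logits = False

    # The context is now passed to the wrapper, and prep_model_for_inference()
    # no longer receives the prompt tokens directly.
    wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
    wrapped_model.prep_model_for_inference()

    # Prompt tokens enter through prep_inference_input / get_batch_for_context_window.
    inference_input = wrapped_model.prep_inference_input(prompt_tokens)
    inference_input = wrapped_model.get_batch_for_context_window(
        inference_input, 0, model.max_sequence_length
    )
    # The forward step that turns this into full-sequence logits follows in the
    # real helper; it is outside the visible hunk, so only the prepared input is
    # returned in this sketch.
    return inference_input

The behavioral difference is materialize_only_last_token_logits = False: per the comment in the hunk, the default context would keep only the last token's logits, while this test helper needs logits for the whole sequence.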
6 changes: 0 additions & 6 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -33,11 +33,9 @@
auto_quantize_helper,
tensor_parallel_test_helper,
)
from packaging.version import Version

skip_if_no_megatron()

import megatron.core
from megatron.core.parallel_state import (
destroy_model_parallel,
get_data_parallel_group,
@@ -256,10 +254,6 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device)
mixed_block_size_config,
],
)
@pytest.mark.skipif(
Version(megatron.core.__version__) <= Version("0.13.0rc1"),
reason="This unittest need megatron.core>=0.13 with default heterogenous ckpt support.",
)
def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config):
spawn_multiprocess_job(
size=2,
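For reference, the guard deleted above followed the standard pytest version-gating pattern; a minimal sketch of that removed pattern is shown below (the module and version string come from the deleted lines, the test body and its fixtures are elided):

import megatron.core
import pytest
from packaging.version import Version


# Guard removed by this diff: skip on Megatron-Core releases that predate
# default heterogeneous checkpoint support.
@pytest.mark.skipif(
    Version(megatron.core.__version__) <= Version("0.13.0rc1"),
    reason="Requires megatron.core >= 0.13 with default heterogeneous ckpt support.",
)
def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config):
    ...

With the guard gone, the corresponding packaging.version and megatron.core imports at the top of the file are also dropped, which is what the first hunk of this file shows.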