Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/speculative_decoding/launch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
--logging_steps 100 \
--tf32 True \
--data_path $DATA \
--report_to tensorboard \
$SPECULATIVE_ARGS
"

Expand Down
19 changes: 18 additions & 1 deletion tests/examples/speculative_decoding/test_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from _test_utils.examples.run_command import run_example_command


# fmt: off
def test_llama_eagle(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
# Create an ultra-tiny EAGLE config for testing to reduce memory usage
tiny_eagle_config = {
"max_position_embeddings": 128,
"num_hidden_layers": 1,
"intermediate_size": 64,
"num_attention_heads": 2,
"num_key_value_heads": 2,
}

# Write the tiny config to a temporary file
config_file = tmp_path / "tiny_eagle_config.json"
with open(config_file, "w") as f:
json.dump(tiny_eagle_config, f)

run_example_command(
[
"./launch.sh",
Expand All @@ -29,7 +44,9 @@ def test_llama_eagle(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_p
"--do_eval", "False",
"--num_gpu", str(num_gpus),
"--mode", "eagle3",
"--eagle_config", str(config_file),
"--output_dir", tmp_path / "eagle-tinyllama",
"--training_seq_len", "128", # Match max_position_embeddings
],
"speculative_decoding",
)
6 changes: 0 additions & 6 deletions tests/gpu/torch/export/test_unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,6 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size):
],
)
def test_unified_export_megatron(tmp_path, model_type, arch, algo):
if algo == "eagle":
try:
import megatron.core.post_training # noqa: F401
except ImportError:
pytest.skip("megatron.core.post_training not found")

# TODO: Fix TP>1 failures
spawn_multiprocess_job(
size=1, # torch.cuda.device_count(),
Expand Down
Loading