Skip to content

Commit cf6f1d4

Browse files
Reduce eagle example test memory usage from 28 to 1 GB (#299)
Signed-off-by: Keval Morabia <[email protected]>
1 parent b233ad1 commit cf6f1d4

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

examples/speculative_decoding/launch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
150150
--logging_steps 100 \
151151
--tf32 True \
152152
--data_path $DATA \
153+
--report_to tensorboard \
153154
$SPECULATIVE_ARGS
154155
"
155156

tests/examples/speculative_decoding/test_eagle.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,27 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import json
1617

1718
from _test_utils.examples.run_command import run_example_command
1819

1920

2021
# fmt: off
21-
def test_llama_eagle(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
22+
def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
23+
# Create an ultra-tiny EAGLE config for testing to reduce memory usage
24+
tiny_eagle_config = {
25+
"max_position_embeddings": 128,
26+
"num_hidden_layers": 1,
27+
"intermediate_size": 64,
28+
"num_attention_heads": 2,
29+
"num_key_value_heads": 2,
30+
}
31+
32+
# Write the tiny config to a temporary file
33+
config_file = tmp_path / "tiny_eagle_config.json"
34+
with open(config_file, "w") as f:
35+
json.dump(tiny_eagle_config, f)
36+
2237
run_example_command(
2338
[
2439
"./launch.sh",
@@ -29,7 +44,9 @@ def test_llama_eagle(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_p
2944
"--do_eval", "False",
3045
"--num_gpu", str(num_gpus),
3146
"--mode", "eagle3",
47+
"--eagle_config", str(config_file),
3248
"--output_dir", tmp_path / "eagle-tinyllama",
49+
"--training_seq_len", "128", # Match max_position_embeddings
3350
],
3451
"speculative_decoding",
3552
)

tests/gpu/torch/export/test_unified_export_megatron.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,6 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size):
110110
],
111111
)
112112
def test_unified_export_megatron(tmp_path, model_type, arch, algo):
113-
if algo == "eagle":
114-
try:
115-
import megatron.core.post_training # noqa: F401
116-
except ImportError:
117-
pytest.skip("megatron.core.post_training not found")
118-
119113
# TODO: Fix TP>1 failures
120114
spawn_multiprocess_job(
121115
size=1, # torch.cuda.device_count(),

0 commit comments

Comments (0)