Skip to content

Commit e8532cc

Browse files
Some small cleanups to the integration test after reverting the speculative-config workaround
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 2a6babd commit e8532cc

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

tests/integration/defs/examples/test_ad_speculative_decoding.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
8686
llm_args = {
8787
"model": model,
8888
"skip_loading_weights": False,
89+
"speculative_config": spec_config,
8990
"runtime": "trtllm",
9091
"world_size": 1,
9192
"kv_cache_config": kv_cache_config,
@@ -106,10 +107,6 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
106107
# Create ExperimentConfig
107108
cfg = ExperimentConfig(**experiment_config)
108109

109-
cfg.args.speculative_config = (
110-
spec_config # Add here to avoid Pydantic validation error for eagle3_layers_to_capture
111-
)
112-
113110
# Add sampling parameters (deterministic with temperature=0.0 and fixed seed)
114111
cfg.prompt.sp_kwargs = {
115112
"max_tokens": 50,
@@ -133,21 +130,20 @@ def test_autodeploy_spec_dec(batch_size):
133130
Runs with and without speculative decoding and verifies outputs are identical.
134131
"""
135132
print("\n" + "=" * 80)
136-
print(f"Testing AutoDeploy Speculative Decoding - Batch Size {batch_size}")
133+
print(f"Testing AutoDeploy Speculative Decoding (Draft Target) - Batch Size {batch_size}")
137134
print("=" * 80)
138135

139136
base_model, draft_target_model, _ = get_model_paths()
140137

141138
print(f"\nBase Model: {base_model}")
142-
spec_model_path = draft_target_model
143-
print(f"Speculative Model: {spec_model_path}")
139+
print(f"Speculative Model: {draft_target_model}")
144140
print(f"Batch Size: {batch_size}")
145141

146142
# Run with speculative decoding
147143
print("\n[1/2] Running with speculative decoding enabled...")
148144
spec_outputs = run_with_autodeploy(
149145
model=base_model,
150-
speculative_model_dir=spec_model_path,
146+
speculative_model_dir=draft_target_model,
151147
batch_size=batch_size,
152148
)
153149
print(f"Generated {len(spec_outputs)} outputs with speculative decoding")

0 commit comments

Comments (0)