Skip to content

Commit 860b4e4

Browse files
committed
update multigpu tests
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 01f0836 commit 860b4e4

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
@pytest.mark.parametrize("world_size", [1, 2])
1111
@pytest.mark.parametrize("mode", ["graph", "transformers"])
1212
@pytest.mark.parametrize(
13-
"experiment_config",
13+
"experiment_config, attn_backend, compile_backend",
1414
[
1515
get_small_model_config_pytest_param(
1616
"meta-llama/Meta-Llama-3.1-8B-Instruct",
@@ -19,10 +19,39 @@
1919
),
2020
],
2121
)
22-
def test_build_ad(world_size: int, experiment_config: Dict, mode: str):
22+
def test_build_ad(
    world_size: int, experiment_config: Dict, attn_backend: str, compile_backend: str, mode: str
):
    """Build-and-run smoke test for AutoDeploy across parametrized configurations.

    Mutates the parametrized ``experiment_config`` dict in place — injecting the
    world size, the default ``trtllm`` runtime, the deployment ``mode``, and a
    mode-specific ``transforms`` section wired to the requested attention and
    compile backends — then validates it as an ``ExperimentConfig`` and executes
    the experiment via ``main``.
    """
    args = experiment_config["args"]
    args["world_size"] = world_size
    args["runtime"] = "trtllm"  # Default runtime set to trtllm
    args["mode"] = mode
    if mode == "graph":
        # Full graph pipeline: cache sizing, attention matching/insertion, compile.
        args["transforms"] = {
            "resize_kv_cache": {
                "stage": "cache_init",
                # 0.00 keeps the KV cache minimal so the test fits on small GPUs.
                "free_mem_ratio": 0.00,
            },
            "match_attention_layout": {
                "stage": "pattern_matcher",
                "attn_backend": attn_backend,
            },
            "insert_cached_attention": {
                "stage": "cache_init",
                "attn_backend": attn_backend,
            },
            "compile_model": {
                "stage": "compile",
                "compile_backend": compile_backend,
            },
        }
    else:
        # "transformers" mode only swaps in cached attention; no graph compile.
        args["transforms"] = {
            "transformers_replace_cached_attn": {
                "stage": "cache_init",
                "attn_backend": attn_backend,
            },
        }
    validated_config = ExperimentConfig(**experiment_config)
    print(f"Experiment Config: {validated_config}")
    main(validated_config)

0 commit comments

Comments
 (0)