|
10 | 10 | @pytest.mark.parametrize("world_size", [1, 2]) |
11 | 11 | @pytest.mark.parametrize("mode", ["graph", "transformers"]) |
12 | 12 | @pytest.mark.parametrize( |
13 | | - "experiment_config", |
| 13 | + "experiment_config, attn_backend, compile_backend", |
14 | 14 | [ |
15 | 15 | get_small_model_config_pytest_param( |
16 | 16 | "meta-llama/Meta-Llama-3.1-8B-Instruct", |
|
19 | 19 | ), |
20 | 20 | ], |
21 | 21 | ) |
def test_build_ad(
    world_size: int, experiment_config: Dict, attn_backend: str, compile_backend: str, mode: str
):
    """Build a small model end-to-end through the AutoDeploy pipeline.

    Overrides the parametrized experiment config with the given world size,
    runtime, mode, and a mode-specific transform pipeline (wired to the
    selected attention and compile backends), then runs ``main()`` on it.
    """
    import copy

    # Parametrize reuses the same dict object across the whole parameter
    # cross-product; deep-copy before mutating so one combination cannot
    # leak overrides (world_size, transforms, ...) into the next test.
    experiment_config = copy.deepcopy(experiment_config)
    experiment_config["args"]["world_size"] = world_size
    experiment_config["args"]["runtime"] = "trtllm"  # Default runtime set to trtllm
    experiment_config["args"]["mode"] = mode
    # The transform pipeline differs by export mode: full graph capture
    # ("graph") vs. the transformers-based cached-attention replacement path.
    experiment_config["args"]["transforms"] = (
        {
            "resize_kv_cache": {
                "stage": "cache_init",
                # 0.0 keeps the KV cache at its minimal size for fast CI runs.
                "free_mem_ratio": 0.00,
            },
            "match_attention_layout": {
                "stage": "pattern_matcher",
                "attn_backend": attn_backend,
            },
            "insert_cached_attention": {
                "stage": "cache_init",
                "attn_backend": attn_backend,
            },
            "compile_model": {
                "stage": "compile",
                "compile_backend": compile_backend,
            },
        }
        if mode == "graph"
        else {
            "transformers_replace_cached_attn": {
                "stage": "cache_init",
                "attn_backend": attn_backend,
            },
        }
    )
    experiment_config = ExperimentConfig(**experiment_config)
    print(f"Experiment Config: {experiment_config}")
    main(experiment_config)
0 commit comments