|
13 | 13 | import transformer_engine |
14 | 14 | from transformer_engine_torch import NVTE_Fused_Attn_Backend |
15 | 15 |
|
16 | | -# Add test_fused_attn to the sys path |
| 16 | +# Add paths tests/pytorch/ and tests/pytorch/attention to the sys path |
17 | 17 | tests_path = os.path.abspath( |
18 | | - os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn") |
| 18 | + os.path.join(os.path.dirname(__file__), "../../tests") |
19 | 19 | ) |
20 | | -sys.path.append(tests_path) |
| 20 | +sys.path.append(tests_path + "/pytorch") |
| 21 | +sys.path.append(tests_path + "/pytorch/attention") |
21 | 22 |
|
22 | | -from test_fused_attn import ( |
| 23 | +# Import shared test helpers from tests/pytorch/utils.py
| 24 | +from utils import ( |
23 | 25 | ModelConfig, |
24 | | - _get_attention_backends, |
| 26 | + get_available_attention_backends, |
| 27 | +) |
| 28 | +from test_attention import ( |
25 | 29 | _run_dot_product_attention, |
26 | 30 | ) |
27 | 31 |
|
|
46 | 50 | is_training = True |
47 | 51 |
|
48 | 52 | model_configs = { |
49 | | - # test: b, h, hg, d, sq, skv, p, mask, bias |
50 | | - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq |
51 | | - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask |
52 | | - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias |
53 | | - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA |
54 | | - "test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias") |
| 53 | + # test: b, sq, h, d |
| 54 | + "test_0": ModelConfig(2, 512, 16, 64), # short seq |
| 55 | + "test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask |
| 56 | + "test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias |
| 57 | + "test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), # GQA |
| 58 | + "test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=16, attn_mask_type="causal_bottom_right") |
55 | 59 | } |
56 | 60 |
|
57 | 61 | # DataFrame indices and columns for results |
@@ -303,7 +307,7 @@ def sanity_checks( |
303 | 307 | } |
304 | 308 |
|
305 | 309 | for model, cfg in model_configs.items(): |
306 | | - avail, _, fused_bes = _get_attention_backends( |
| 310 | + avail, _, fused_bes = get_available_attention_backends( |
307 | 311 | cfg, |
308 | 312 | qkv_dtype=dtype, |
309 | 313 | qkv_layout=qkv_layout, |
@@ -364,7 +368,7 @@ def main(args): |
364 | 368 | # Benchmarking starts.. |
365 | 369 | for model in model_configs.keys(): |
366 | 370 | config = model_configs[model] |
367 | | - available_backends, _, fused_attn_backends = _get_attention_backends( |
| 371 | + available_backends, _, fused_attn_backends = get_available_attention_backends( |
368 | 372 | config, |
369 | 373 | qkv_dtype=dtype, |
370 | 374 | qkv_layout=qkv_layout, |
|