Commit 97556c6

[ROCm] Resolve conflicts
1 parent 966a4ac commit 97556c6

File tree

23 files changed: +401 additions, -1219 deletions


benchmarks/attention/benchmark_attention_rocm.py

Lines changed: 17 additions & 13 deletions

@@ -13,15 +13,19 @@
 import transformer_engine
 from transformer_engine_torch import NVTE_Fused_Attn_Backend

-# Add test_fused_attn to the sys path
+# Add paths tests/pytorch/ and tests/pytorch/attention to the sys path
 tests_path = os.path.abspath(
-    os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn")
+    os.path.join(os.path.dirname(__file__), "../../tests")
 )
-sys.path.append(tests_path)
+sys.path.append(tests_path + "/pytorch")
+sys.path.append(tests_path + "/pytorch/attention")

-from test_fused_attn import (
+# Add tests/pytorch/utils.py path into sys path
+from utils import (
     ModelConfig,
-    _get_attention_backends,
+    get_available_attention_backends,
+)
+from test_attention import (
     _run_dot_product_attention,
 )

@@ -46,12 +50,12 @@
 is_training = True

 model_configs = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
-    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
-    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
-    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
-    "test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias")
+    # test: b, sq, h, d
+    "test_0": ModelConfig(2, 512, 16, 64),  # short seq
+    "test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),  # longer seq, mask
+    "test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),  # bias
+    "test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),  # GQA
+    "test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=16, attn_mask_type="causal_bottom_right")
 }

 # DataFrame indices and columns for results
@@ -303,7 +307,7 @@ def sanity_checks(
     }

     for model, cfg in model_configs.items():
-        avail, _, fused_bes = _get_attention_backends(
+        avail, _, fused_bes = get_available_attention_backends(
             cfg,
             qkv_dtype=dtype,
             qkv_layout=qkv_layout,
@@ -364,7 +368,7 @@ def main(args):
     # Benchmarking starts..
     for model in model_configs.keys():
         config = model_configs[model]
-        available_backends, _, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
             qkv_dtype=dtype,
             qkv_layout=qkv_layout,
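
For context, here is a minimal sketch of how the benchmark resolves its test helpers after this change. The path setup, the keyword-style ModelConfig arguments, and the three-value return of get_available_attention_backends are taken from the hunks above; the dtype and qkv_layout values are illustrative assumptions, not part of this commit.

```python
# Sketch only: mirrors the import/path changes above. The qkv_layout string and
# the bfloat16 dtype are assumed example values, not taken from this commit.
import os
import sys

import torch

tests_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../tests"))
sys.path.append(tests_path + "/pytorch")            # provides utils.py
sys.path.append(tests_path + "/pytorch/attention")  # provides test_attention.py

from utils import ModelConfig, get_available_attention_backends  # noqa: E402
from test_attention import _run_dot_product_attention  # noqa: E402

# New-style config: batch, sequence length, heads, head dim, plus keyword options.
cfg = ModelConfig(2, 2048, 16, 128, attn_mask_type="causal")

available_backends, _, fused_attn_backends = get_available_attention_backends(
    cfg,
    qkv_dtype=torch.bfloat16,     # assumed dtype
    qkv_layout="bshd_bshd_bshd",  # assumed layout string
)
print(available_backends, fused_attn_backends)
```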

build_tools/pytorch.py

Lines changed: 0 additions & 17 deletions

@@ -27,24 +27,7 @@

 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-<<<<<<< HEAD
-    reqs = ["einops"]
-    if not rocm_build():
-        reqs.append(
-            "nvdlfw-inspect @"
-            " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
-        )
-    reqs.extend(
-        [
-            "torch>=2.1",
-            "onnx",
-            "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
-        ]
-    )
-    return reqs
-=======
     return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]
->>>>>>> upstream/release_v2.6


 def test_requirements() -> List[str]:
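
For reference, a sketch of how the resolved function reads once the conflict markers are removed. The body is the single kept line from the hunk above; the typing import is assumed to already exist elsewhere in build_tools/pytorch.py.

```python
from typing import List  # assumed to be imported at the top of the module


def install_requirements() -> List[str]:
    """Install dependencies for TE/PyTorch extensions."""
    # Conflict resolved in favor of the upstream/release_v2.6 side: a pinned
    # onnxscript release instead of a git-pinned commit, and no conditional
    # nvdlfw-inspect requirement.
    return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]
```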

ci/pytorch.sh

Lines changed: 3 additions & 3 deletions

@@ -65,7 +65,7 @@ run_test_config(){
     run_default_fa 1 test_recipe.py
     run 1 test_sanity.py
     run_default_fa 1 test_sanity_import.py
-    run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test
+    run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test
     run_default_fa 1 triton_kernels/test_cast.py
     run_default_fa 1 triton_kernels/test_cast_mxfp8.py
     run_default_fa 1 triton_kernels/test_norm_common.py
@@ -88,8 +88,8 @@ run_test_config_mgpu(){
     run_default_fa 2 distributed/test_numerics.py
     run_default_fa 1 distributed/test_torch_fsdp2.py
     run_default_fa 2 distributed/test_torch_fsdp2_fp8.py
-    run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash"
-    run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused"
+    run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash"
+    run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused"
 }

 run_benchmark() {
