Skip to content

Commit e532f3a

Browse files
authored
Merge branch 'ROCm:main' into main
2 parents e0c5114 + b5ade22 commit e532f3a

File tree

187 files changed

+11645
-3242
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

187 files changed

+11645
-3242
lines changed

.github/scripts/build_aiter_triton.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pip install --upgrade pandas zmq einops numpy==1.26.2
1212
pip uninstall -y aiter || true
1313
pip install --upgrade "pybind11>=3.0.1"
1414
pip install --upgrade "ninja>=1.11.1"
15+
pip install tabulate
1516
python3 setup.py develop
1617

1718
# Read BUILD_TRITON env var, default to 1. If 1, install Triton; if 0, skip installation.

.github/scripts/collect_logs.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python3
2+
import re
3+
import sys
4+
from pathlib import Path
5+
6+
7+
def extract_markdown_blocks(path: Path):
    """Extract "[aiter] ... summary (markdown):" blocks from a log file.

    A block consists of a header line matching
    ``[aiter] <operator> summary (markdown):`` followed by consecutive
    markdown table rows (lines starting with ``|``).

    Args:
        path: Path of the log file to scan. Read with ``errors="ignore"``
            because test logs may contain undecodable bytes.

    Returns:
        A list of blocks; each block is a list of lines (header first,
        then the table rows), with trailing newlines stripped.
    """
    start_pattern = re.compile(r"^\[aiter\]\s+.*summary\s*\(markdown\):")
    table_line_pattern = re.compile(r"^\|")
    blocks = []

    with path.open("r", encoding="utf-8", errors="ignore") as f:
        in_block = False
        current_block = []

        for line in f:
            stripped = line.rstrip("\n")

            if in_block:
                if table_line_pattern.match(stripped):
                    current_block.append(stripped)
                    continue
                # A non-table line terminates the current block.
                blocks.append(current_block)
                in_block = False
                current_block = []

            # Bug fix: the line that terminates a block may itself be the
            # header of the next block (no separator in between). Re-check it
            # here instead of dropping it, so adjacent blocks are all found.
            if start_pattern.match(stripped):
                in_block = True
                current_block = [stripped]

    # File ended while still inside a block — flush it.
    if in_block and current_block:
        blocks.append(current_block)

    return blocks
46+
47+
48+
def main():
    """CLI entry point: print every extracted markdown block to stdout.

    Usage: collect_logs.py <log_file>. Blocks are separated by one blank
    line; exits with status 1 on missing argument or nonexistent file.
    """
    args = sys.argv[1:]
    if not args:
        print("Usage: collect_logs.py <log_file>", file=sys.stderr)
        sys.exit(1)

    log_path = Path(args[0])
    if not log_path.exists():
        print(f"File not found: {log_path}", file=sys.stderr)
        sys.exit(1)

    blocks = extract_markdown_blocks(log_path)

    last = len(blocks) - 1
    for i, block in enumerate(blocks):
        print("\n".join(block))
        if i != last:
            print()


if __name__ == "__main__":
    main()

.github/workflows/aiter-test.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,14 @@ jobs:
133133
-w /workspace \
134134
aiter_test \
135135
bash -c "MAX_JOBS=20 ./.github/scripts/aiter_test.sh"
136-
136+
137+
- name: Collect test logs
138+
if: always()
139+
run: |
140+
echo "Collecting test logs..."
141+
echo "Aiter Operator Tests Summary:" >> $GITHUB_STEP_SUMMARY
142+
python3 ./.github/scripts/collect_logs.py latest_test.log >> $GITHUB_STEP_SUMMARY
143+
137144
- name: Upload test logs
138145
uses: actions/upload-artifact@v4
139146
if: always()
@@ -242,4 +249,3 @@ jobs:
242249
if: always()
243250
run: |
244251
./.github/scripts/clean_up_rocm.sh
245-

.github/workflows/operators-tuning.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ jobs:
103103
./.github/scripts/op_tune.sh test "${{ github.event.inputs.shapes }}"
104104
105105
- name: Upload tuned CSVs
106+
if: always()
106107
uses: actions/upload-artifact@v4
107108
with:
108109
name: tuned-csvs

3rdparty/composable_kernel

Submodule composable_kernel updated 70 files

aiter/__init__.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -43,49 +43,50 @@ def getLogger():
4343

4444
logger = getLogger()
4545

46-
47-
from .jit import core as core
48-
from .utility import dtypes as dtypes
49-
from .ops.enum import *
50-
from .ops.norm import *
51-
from .ops.quant import *
52-
from .ops.gemm_op_a8w8 import *
53-
from .ops.gemm_op_a16w16 import *
54-
from .ops.gemm_op_a4w4 import *
55-
from .ops.batched_gemm_op_a8w8 import *
56-
from .ops.batched_gemm_op_bf16 import *
57-
from .ops.deepgemm import *
58-
from .ops.aiter_operator import *
59-
from .ops.activation import *
60-
from .ops.attention import *
61-
from .ops.custom import *
62-
from .ops.custom_all_reduce import *
63-
from .ops.quick_all_reduce import *
64-
from .ops.moe_op import *
65-
from .ops.moe_sorting import *
66-
from .ops.pos_encoding import *
67-
from .ops.cache import *
68-
from .ops.rmsnorm import *
69-
from .ops.communication import *
70-
from .ops.rope import *
71-
from .ops.topk import *
72-
from .ops.topk_plain import topk_plain
73-
from .ops.mha import *
74-
from .ops.gradlib import *
75-
from .ops.trans_ragged_layout import *
76-
from .ops.sample import *
77-
from .ops.fused_mrope_rms import *
78-
from . import mla
46+
from .jit import core as core # noqa: E402
47+
from .utility import dtypes as dtypes # noqa: E402
48+
from .ops.enum import * # noqa: F403,E402
49+
from .ops.norm import * # noqa: F403,E402
50+
from .ops.quant import * # noqa: F403,E402
51+
from .ops.gemm_op_a8w8 import * # noqa: F403,E402
52+
from .ops.gemm_op_a16w16 import * # noqa: F403,E402
53+
from .ops.gemm_op_a4w4 import * # noqa: F403,E402
54+
from .ops.batched_gemm_op_a8w8 import * # noqa: F403,E402
55+
from .ops.batched_gemm_op_bf16 import * # noqa: F403,E402
56+
from .ops.deepgemm import * # noqa: F403,E402
57+
from .ops.aiter_operator import * # noqa: F403,E402
58+
from .ops.activation import * # noqa: F403,E402
59+
from .ops.attention import * # noqa: F403,E402
60+
from .ops.custom import * # noqa: F403,E402
61+
from .ops.custom_all_reduce import * # noqa: F403,E402
62+
from .ops.quick_all_reduce import * # noqa: F403,E402
63+
from .ops.moe_op import * # noqa: F403,E402
64+
from .ops.moe_sorting import * # noqa: F403,E402
65+
from .ops.pos_encoding import * # noqa: F403,E402
66+
from .ops.cache import * # noqa: F403,E402
67+
from .ops.rmsnorm import * # noqa: F403,E402
68+
from .ops.communication import * # noqa: F403,E402
69+
from .ops.rope import * # noqa: F403,E402
70+
from .ops.topk import * # noqa: F403,E402
71+
from .ops.topk_plain import topk_plain # noqa: F403,F401,E402
72+
from .ops.mha import * # noqa: F403,E402
73+
from .ops.gradlib import * # noqa: F403,E402
74+
from .ops.trans_ragged_layout import * # noqa: F403,E402
75+
from .ops.sample import * # noqa: F403,E402
76+
from .ops.fused_mrope_rms import * # noqa: F403,E402
77+
from .ops.fused_qk_norm_rope_cache_quant import * # noqa: F403,E402
78+
from .ops.groupnorm import * # noqa: F403,E402
79+
from . import mla # noqa: F403,F401,E402
7980

8081
# Import Triton-based communication primitives from ops.triton.comms (optional, only if Iris is available)
8182
try:
8283
from .ops.triton.comms import (
83-
IrisCommContext,
84-
calculate_heap_size,
85-
reduce_scatter,
86-
all_gather,
87-
reduce_scatter_rmsnorm_quant_all_gather,
88-
IRIS_COMM_AVAILABLE,
84+
IrisCommContext, # noqa: F401
85+
calculate_heap_size, # noqa: F401
86+
reduce_scatter, # noqa: F401
87+
all_gather, # noqa: F401
88+
reduce_scatter_rmsnorm_quant_all_gather, # noqa: F401
89+
IRIS_COMM_AVAILABLE, # noqa: F401
8990
)
9091
except ImportError:
9192
# Iris not available, skip import

aiter/aot/sampling.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from collections import namedtuple
2+
import os
3+
import concurrent.futures
4+
from csrc.cpp_itfs.sampling.top_k_renorm_probs import (
5+
compile as top_k_renorm_probs_compile,
6+
)
7+
from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import (
8+
compile as top_p_sampling_from_probs_compile,
9+
)
10+
from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import (
11+
compile as top_k_top_p_sampling_from_probs_compile,
12+
)
13+
14+
# Immutable parameter bundles, one per kernel-compilation job.
TopKRenormConfig = namedtuple("TopKRenormConfig", ["vec_size", "func_name"])

TopPSamplingConfig = namedtuple(
    "TopPSamplingConfig", ["vec_size", "deterministic", "func_name"]
)

TopKTopPSamplingConfig = namedtuple(
    "TopKTopPSamplingConfig", ["vec_size", "deterministic", "func_name"]
)
28+
29+
30+
def process_top_k_renorm_config(config):
    """Compile the top_k_renorm_probs kernel for one config (worker task)."""
    vec_size = config.vec_size
    return top_k_renorm_probs_compile(vec_size)
32+
33+
34+
def process_top_p_sampling_config(config):
    """Compile the top_p_sampling_from_probs kernel for one config (worker task)."""
    vec_size = config.vec_size
    deterministic = config.deterministic
    return top_p_sampling_from_probs_compile(vec_size, deterministic)
36+
37+
38+
def process_top_k_top_p_sampling_config(config):
    """Compile the top_k_top_p_sampling_from_probs kernel for one config (worker task)."""
    vec_size = config.vec_size
    deterministic = config.deterministic
    return top_k_top_p_sampling_from_probs_compile(vec_size, deterministic)
42+
43+
44+
def main():
    """Pre-compile all sampling kernel variants ahead of time.

    Builds the full configuration grid for the three sampling kernels
    (vec_size 1..4, and where applicable deterministic False/True) and
    compiles them in parallel across processes. The worker count is capped
    by the MAX_JOBS environment variable (default: CPU count, or 16).
    """
    vec_sizes = range(1, 5)

    top_k_renorm_configs = [
        TopKRenormConfig(vec_size=vec_size, func_name="top_k_renorm_probs")
        for vec_size in vec_sizes
    ]

    top_p_sampling_configs = [
        TopPSamplingConfig(
            vec_size=vec_size,
            deterministic=deterministic,
            func_name="top_p_sampling_from_probs",
        )
        for vec_size in vec_sizes
        for deterministic in (False, True)
    ]

    top_k_top_p_sampling_configs = [
        TopKTopPSamplingConfig(
            vec_size=vec_size,
            deterministic=deterministic,
            func_name="top_k_top_p_sampling_from_probs",
        )
        for vec_size in vec_sizes
        for deterministic in (False, True)
    ]

    max_jobs = int(os.environ.get("MAX_JOBS", os.cpu_count() or 16))

    # Process all configs in parallel.
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as executor:
        # Bug fix: Executor.map only re-raises worker exceptions when its
        # result iterator is consumed. The original never iterated, so any
        # compile failure was silently ignored; list() drains each iterator
        # and lets failures propagate.
        list(executor.map(process_top_k_renorm_config, top_k_renorm_configs))
        list(executor.map(process_top_p_sampling_config, top_p_sampling_configs))
        list(
            executor.map(
                process_top_k_top_p_sampling_config, top_k_top_p_sampling_configs
            )
        )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)