47 commits
- 826bda0  [FA] 4-stage FA pipeliner (AlexAUT, Apr 15, 2025)
- c35e297  [FA] Add FA scripts (AlexAUT, Apr 15, 2025)
- 06cf75a  [FA] Place cvt layout in the same stage and cluster as LocalLoad so c… (AlexAUT, Apr 16, 2025)
- 203fe11  [ASYNC_COPY] Add env var to bypass permute, only works if the load di… (AlexAUT, Apr 23, 2025)
- b664353  [FA] Do not combine AsyncWaits to have a barrier in front of each mem… (AlexAUT, Apr 23, 2025)
- 012793a  [ASYNC_COPY] Remove MemoryEffect of BufferLoadToLocal to avoid implic… (AlexAUT, Apr 23, 2025)
- 3b74f4a  [FA] Compute max before mul QK_SCALE to fold sub into fma (AlexAUT, Apr 23, 2025)
- b059372  [FA] Added 2 extra clusters to have async_waits in front of memory cl… (AlexAUT, Apr 24, 2025)
- f3bb293  [FA] Place LocalLoads before AsyncCopies (AlexAUT, Apr 24, 2025)
- 77884fa  [FA][ASYNC_COPY] Force vec=8 for shared encodings to avoid 32bit dire… (AlexAUT, Apr 24, 2025)
- b2e2ad0  [FA] Place dots at the top of clusters (AlexAUT, Apr 24, 2025)
- fab1281  [FA] Split 4-stage clusters into 8 clusters to better controll the or… (AlexAUT, Apr 24, 2025)
- e0ea5e7  [FA] Revert order change in SM clusters (AlexAUT, Apr 24, 2025)
- 1198462  [FA] Set vecSize=nonKDim for V shared layout to avoid bank conflicts (zhanglx13, Apr 27, 2025)
- fb186d4  [FA] Removed old vectorSize workaround (AlexAUT, Apr 28, 2025)
- 3212481  [FA] Revert "Place AsyncWait at the top of schedule" (AlexAUT, Apr 29, 2025)
- 34beed7  [FA][PINGPONG] Add support for FAv3 pingpong. (jungpark-mlir, Apr 29, 2025)
- 3861063  [FA][PINGPONG] Allow block pingpong with num_stages==4 (AlexAUT, Apr 30, 2025)
- fc6d1d9  [FA][PINGPONG] Bail out if async wait count != 2 (AlexAUT, May 2, 2025)
- d6a0419  [FA] Do not pipeline second loop (causal) (AlexAUT, May 3, 2025)
- 1d1e8cc  [FA] Split FourStagePipeliner to separate file and do very basic sele… (AlexAUT, May 7, 2025)
- d46f750  [GEMM] Add combine dot_scaled and addF (jungpark-mlir, May 1, 2025)
- 4a5ece6  [GEMM] Do not swizzle the scale (AlexAUT, May 6, 2025)
- 8285bfc  Add layout conversion pass optim at the end (vgokhale, May 12, 2025)
- 916cb06  Initial commit to enable pingpong for dot_scaled with mxfp4 (jungpark-mlir, May 12, 2025)
- b3c2f94  Fix to the gemm pingpong. (jungpark-mlir, May 13, 2025)
- a19dd6d  Add restriction to dot_scaled pingpong. (jungpark-mlir, May 13, 2025)
- f6065b9  Revert "[AMD] Use v_permlane to optimize MFAM to linear layout on GFX… (AlexAUT, May 13, 2025)
- d7e2e2c  Revert "[BACKEND] bump to llvm/llvm-project@3c709802d31b (#6754)" (AlexAUT, May 13, 2025)
- 247f4f4  Revert because no longer needed: "[ASYNC_COPY] Remove MemoryEffect of… (AlexAUT, May 14, 2025)
- 1028c8f  [AMD] Enable async pingpong for F16 GEMMs (#796) (raikonenfnu, May 15, 2025)
- c5c0e67  Add initial support for skinny mxfp gemm (jungpark-mlir, May 17, 2025)
- bcc871d  add AB load separated pingpong for skinny gemm. (jungpark-mlir, May 18, 2025)
- 1b2a86b  [AMD] add slicing `async-copy-local-to-global` (ravil-mobile, May 15, 2025)
- 33f6ce9  Revert "Revert "[AMD] Use v_permlane to optimize MFAM to linear layou… (antiagainst, May 19, 2025)
- 0f7bbc2  [AMD] Use composition to swap columns for mfma like store layout (#6844) (antiagainst, May 16, 2025)
- aebdfd7  [ASYNCCOPY] Simplify swizzling calculations to get better codegen from… (AlexAUT, May 16, 2025)
- 6527f10  Code cleanup (jungpark-mlir, May 19, 2025)
- 1082cd2  Add skinny pingpong transform (jungpark-mlir, May 20, 2025)
- 5c4b1fb  [FA] Disable pipelining for causal loop (AlexAUT, May 20, 2025)
- 18ae32b  [AMD] Add an option to force async copy overlapping (joviliast, May 16, 2025)
- 77c00fa  [AMD] Add an option to force async copy overlapping (joviliast, May 21, 2025)
- c5ceb64  [AMD] Improved CanonicalizePointers for ExtractSlice (ravil-mobile, May 19, 2025)
- a89b3b4  Merge branch 'shared/triton-gfx950-launch' into shared/triton-gfx950-… (ravil-mobile, May 21, 2025)
- a981b01  [AMD] Add a Concat op to AMDGPU dialect (#6590) (plognjen, May 20, 2025)
- 6a6fb70  WA for incorrect strides in subview (AlexAUT, May 21, 2025)
- 34538bc  [AMD] improved subviewing for async-copy-local-to-global (ravil-mobile, May 26, 2025)
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests-amd.yml
@@ -109,7 +109,7 @@ jobs:
           echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
         fi
         pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
-        pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
+        pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
         TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
         cd python/test/unit
         pytest --capture=tee-sys -rfs -n 12 language runtime \
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
-3c709802d31b5bc5ed3af8284b40593ff39b9eec
+092b6e73e651469527662443b592f98f442ece72
2,139 changes: 2,139 additions & 0 deletions fa/flash-attention.py

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions fa/model_configs.json
@@ -0,0 +1,42 @@
{
    "llama3": {
        "8B": {
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "hidden_size": 4096,
            "intermediate_size": 14336,
            "vocab_size": 128256
        },
        "70B": {
            "num_attention_heads": 64,
            "num_key_value_heads": 8,
            "hidden_size": 8192,
            "intermediate_size": 28672,
            "vocab_size": 128256
        },
        "405B": {
            "num_attention_heads": 128,
            "num_key_value_heads": 8,
            "hidden_size": 16384,
            "intermediate_size": 53248,
            "vocab_size": 128256
        }
    },
    "mistral": {
        "7B": {
            "hidden_size": 4096,
            "intermediate_size": 14336,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "vocab_size": 32000
        },
        "22B": {
            "hidden_size": 6144,
            "intermediate_size": 16384,
            "num_attention_heads": 48,
            "num_key_value_heads": 8,
            "vocab_size": 32000
        }
    }
}
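
For reference, a minimal sketch (not part of this PR) of how these entries can be interpreted: each model records its attention geometry, from which the per-head dimension and the grouped-query-attention (GQA) group size follow. The numbers in the comments come from the llama3-8B entry above; the snippet itself is illustrative only.

```python
import json

# Minimal sketch: derive per-head dimension and GQA group size
# from one entry in fa/model_configs.json.
with open("fa/model_configs.json") as f:
    cfg = json.load(f)["llama3"]["8B"]

head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]            # 4096 // 32 = 128
gqa_groups = cfg["num_attention_heads"] // cfg["num_key_value_heads"]  # 32 // 8 = 4

print(f"head_dim={head_dim}, gqa_groups={gqa_groups}")
```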
Empty file added fa/utils/__init__.py
71 changes: 71 additions & 0 deletions fa/utils/benchmark_utils.py
@@ -0,0 +1,71 @@
import os
import json

# Base directory where configs are located
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))


def get_model_configs(config_path='model_configs.json', model_families=["llama3"], model="all"):
    """
    Load model configurations from the configuration file.

    Args:
        config_path (str): User-provided path to the configuration JSON file.
        model_families (list): List of model family names to retrieve.
        model (str): "all" to keep every model in the selected families, or a
            specific name such as "llama3_8B" / "llama3-8B" to filter by.

    Returns:
        dict: A dictionary of available models and their configurations for the specified families.
    """
    # Resolve config path relative to the fa/ directory
    config_path = os.path.join(BASE_DIR, config_path)

    with open(config_path, 'r') as f:
        configs = json.load(f)

    # Extract models and their configurations for the specified families
    filtered_configs = {}

    for family in model_families:
        if family in configs:
            # Check if model filtering is required
            if model == "all":
                # Include all models in the family
                for model_size, model_configs in configs[family].items():
                    filtered_configs[f"{family}-{model_size}"] = model_configs
            else:
                # Parse the model string (e.g., llama3_8B or llama3-8B)
                delimiter = "_" if "_" in model else "-"
                model_parts = model.split(delimiter)

                # Check if the family and size match
                if len(model_parts) == 2 and model_parts[0] == family:
                    model_size = model_parts[1]
                    if model_size in configs[family]:
                        filtered_configs[f"{family}-{model_size}"] = configs[family][model_size]

    if not filtered_configs:
        print(f"Warning: No models selected for families: {model_families} with filter: '{model}'")

    return filtered_configs


def get_available_models(config_file='model_configs.json', model_families=["llama3"]):
    """
    Load model names from the configuration file.

    Args:
        config_file (str): Path to the configuration JSON file.
        model_families (list): List of model family names to retrieve.

    Returns:
        list: A list of available models for the specified families.
    """
    # Resolve config path relative to the fa/ directory
    config_path = os.path.join(BASE_DIR, config_file)

    with open(config_path, 'r') as f:
        configs = json.load(f)

    models = [f"{family}-{model}" for family in model_families if family in configs for model in configs[family]]

    return models
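
A minimal usage sketch for the two helpers above (not part of this PR), assuming it is run from the repository root so that fa/ resolves as a package; the printed values follow from fa/model_configs.json.

```python
# Illustrative usage of the helpers added in fa/utils/benchmark_utils.py.
from fa.utils.benchmark_utils import get_available_models, get_model_configs

# List every llama3 and mistral model described in fa/model_configs.json.
print(get_available_models(model_families=["llama3", "mistral"]))
# e.g. ['llama3-8B', 'llama3-70B', 'llama3-405B', 'mistral-7B', 'mistral-22B']

# Pull the configuration for a single model; "llama3_8B" and "llama3-8B"
# are both accepted by the filter.
cfg = get_model_configs(model_families=["llama3"], model="llama3-8B")
print(cfg["llama3-8B"]["hidden_size"])  # 4096
```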
59 changes: 59 additions & 0 deletions fa/utils/rocprof_benchmark.py
@@ -0,0 +1,59 @@
import subprocess
import os
import pandas as pd
from prettytable import PrettyTable


def run_profiling(triton_dir, batch_size, output_file):
    command = [
        "rocprof", "--stats", "-o", output_file, "python", f"{triton_dir}/python/perf-kernels/MLA_decode_rope.py", "-B",
        str(batch_size), "-dtype", "bf16", "-use_rope"
    ]
    subprocess.run(command, check=True)


def parse_profiling_output(output_file, kernel_names):
    df = pd.read_csv(output_file)
    results = {}
    for kernel in kernel_names:
        kernel_data = df[df['Name'].str.strip('"') == kernel]
        if not kernel_data.empty:
            results[kernel] = kernel_data['AverageNs'].iloc[0] / 1000.0
        else:
            results[kernel] = None

    # Calculate sum of other kernels
    other_kernels = df[~df['Name'].str.strip('"').isin(kernel_names)]
    other_kernels_sum = other_kernels['AverageNs'].sum() / 1000.0
    results['other_kernels_sum'] = other_kernels_sum

    return results


def main():
    # Default to ~/triton if TRITONDIR is not set; expand '~' since subprocess does not use a shell
    triton_dir = os.path.expanduser(os.environ.get("TRITONDIR", "~/triton"))
    output_file = os.path.expanduser("~/profiling.csv")
    kernel_names = ["_fwd_grouped_kernel_stage1_rope.kd", "_fwd_grouped_kernel_stage1.kd"]
    batch_sizes = [1, 4, 32, 64, 128]

    results = {B: {} for B in batch_sizes}
    for B in batch_sizes:
        print(f"Running profiling for B={B}...")
        run_profiling(triton_dir, B, output_file)
        output_stats_file = os.path.expanduser("~/profiling.stats.csv")
        kernel_results = parse_profiling_output(output_stats_file, kernel_names)
        results[B] = kernel_results

    table = PrettyTable()
    table.field_names = ["B"] + kernel_names + ["Other Kernels Sum (µs)"]
    for B in batch_sizes:
        row = [B] + [results[B].get(kernel, "N/A")
                     for kernel in kernel_names] + [results[B].get('other_kernels_sum', "N/A")]
        table.add_row(row)

    print("\nProfiling Summary (in microseconds):")
    print(table)


if __name__ == "__main__":
    main()
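
For context, a small self-contained sketch (not part of this PR) of the stats table shape that parse_profiling_output consumes: rocprof's --stats CSV carries at least the Name and AverageNs columns, which is all the parser relies on. The timings and the third kernel name below are made up purely for illustration.

```python
# Toy illustration: exercise parse_profiling_output on a fabricated stats table
# with the two columns it relies on ("Name", "AverageNs").
import pandas as pd
from fa.utils.rocprof_benchmark import parse_profiling_output

toy = pd.DataFrame({
    "Name": ['"_fwd_grouped_kernel_stage1_rope.kd"', '"_fwd_grouped_kernel_stage1.kd"', '"some_other_kernel.kd"'],
    "AverageNs": [125000, 98000, 42000],
})
toy.to_csv("/tmp/toy_stats.csv", index=False)

print(parse_profiling_output("/tmp/toy_stats.csv",
                             ["_fwd_grouped_kernel_stage1_rope.kd", "_fwd_grouped_kernel_stage1.kd"]))
# {'_fwd_grouped_kernel_stage1_rope.kd': 125.0, '_fwd_grouped_kernel_stage1.kd': 98.0, 'other_kernels_sum': 42.0}
```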