Commit 26ce12c

Swap FFA backward QK loop (#204)
* added vis_cute_layout utils func
* minor updated vis_cute_layout utils func
* added debug print code for fwd
* added debug print code for bwd
* fixed fwd debug print
* added swap_bwd_qk_loop to ffa args and adjusted the order of args and updated the docstring and updated the usage of ffa fwd/bwd funcs
* added SwapBwdQKLoop template into ffa bwd; passed swap_bwd_qk_loop flag through ffa jit system
* inited up the framework for ffa_bwd loop k
* moved static_switch.h and utils.h to ffa sub-dir
* added canonical_warp_idx_in_warpgroup sync/nosync to utils.h
* added canonical_thread_idx_in_warpgroup nosync/sync utils func
* added sync_cga_threads to utils; inited shared storage and pipeline init
* inited the bwd schedule func; added BwdNamedBarriersLoopK
* renamed funcs for loop-q
* partial implemented load_with_loop_k; added get_tma_multi_cast_meta, sizeof_bytes_v utils funcs
* simplified the usage of TileShape_MNK
* simplified the usage of cutlass::gemm::collective::detail
* inited the smem layout for swap_qk_loop; polished much
* inited the params for swap_qk_loop
* removed runtime debug code
* minor polished load func for fwd/bwd
* refactored load func for fwd/bwd
* implemented the load func (i.e. producer) for swap_qk_loop
* rich-commented scheduling code for fwd/bwd; inited the consumer scheduling code for swap_qk_loop
* removed all the left debug code
* minor polished bwd mma func
* added dKV acc smem layout and tensor storage
* updated bwd args when switching swap qk loop; unified /* DEBUG */ tag
* added some softmax helper funcs; half-implemented mma func until P,dS
* repolished the comments of bwd mma func
* implemented dv gemm
* rich-commented the second half of bwd mma when not swap_qk_loop
* finished mma func for swap_qk_loop w/o Slice_dQKV_Mma
* implemented mma func when swap_qk_loop with Slice_dQKV_Mma
* removed temp debug signature
* implemented store_kv func
* implemented epilogue store funcs
* make auto_range_merge a jit template parameter
* minor fixed some compilation error
* found the named barriers issue and added static assertion
* implemented barrier manager
* added barrier traits and extended named barrier sync/arrive with raw barrier IDs; deleted flash::named_barrier_xxx APIs
* fixed another out-of-smem-limit issue from {64, 128, 128} to {64, 64, 128}
* minor fixed kwargs of Seqlenk_mask
* fixed scheduler_args from k ranges to q ranges when swap_qk_loop
* minor fixed a typo
* minor fixed layout idx but left fixme when not using tma
* added temp debug code to align the tile size with before, reducing shared memory usage by ignoring dk
* fixed the store dkv bidh -> bidh_kv bug
* updated temp debug code to align the tile size with before, reducing shared memory usage by ignoring dv, instead of ignoring dk
* fixed hung bug for get_lse_scaled when not all lanes get into it
* removed softcap=True from prebuild to shorten compilation time
* added temp debug code for test ffa
* removed all the debug code and adjusted the atom layouts
* fixed the missing SwapBwdQKLoop template param for tile_size_bwd_sm90 in prepare_mha_bwd
* removed debug print code
* removed all remaining debug code and added swap_bwd_qk_loop to test_ffa
* fixed the disable_bwd_dkv_atomic_reduction to be only enabled with MHA
* optimized check_mask_lse only for last m block job of each batch
* minor polished tile scheduler and added count_in_warp utils func
* added swap_bwd_qk_loop dense benchmark
* added merge csv utils script
* refactored store dq,dk,dv
* make producer storer to 2 warps; added some static assertion for swap qk loop
* renamed atom layout
* added one comment for NumProducerThreads
* updated merge_csv utils script
* renamed the file name due to typo
* merged main as one single commit
* fixed a missing arg
* adjusted the ffa arg order
* minor polished ffa fwd code
* minor fixed
* updated prebuild logics in setup.py
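
The core of the change is the loop order in the FFA backward kernel. Below is a loose NumPy sketch (an editor's illustration, not the kernel's code) of what swapping the backward QK loop means at the block level; the function name attn_bwd_blocked, the blk block size, and the swap_qk_loop argument here are made up for this sketch. The per-block math is identical in both orders; what changes is which gradient accumulator crosses outer-loop iterations, and therefore, once outer iterations are spread across CTAs as in the real kernel, which gradient needs cross-block atomic reduction: dQ in the default K-outer schedule, dK/dV in the swapped Q-outer schedule, consistent with the disable_bwd_dkv_atomic_reduction item above.

# Editor's sketch (NumPy, single head, no masking): same blockwise backward math,
# two possible loop orders. All names here are illustrative, not the kernel's.
import numpy as np

def attn_bwd_blocked(q, k, v, do, lse, delta, scale, blk=64, swap_qk_loop=False):
    """lse[i] = logsumexp_j(scale * q_i . k_j), delta = rowsum(do * o)."""
    dq, dk, dv = np.zeros_like(q), np.zeros_like(k), np.zeros_like(v)
    q_starts = range(0, q.shape[0], blk)
    kv_starts = range(0, k.shape[0], blk)
    # The only difference between the two schedules is which loop is outermost.
    outer, inner = (q_starts, kv_starts) if swap_qk_loop else (kv_starts, q_starts)
    for a in outer:
        for b in inner:
            i, j = (a, b) if swap_qk_loop else (b, a)  # i: Q block start, j: K/V block start
            qi, kj, vj = q[i:i + blk], k[j:j + blk], v[j:j + blk]
            s = scale * qi @ kj.T
            p = np.exp(s - lse[i:i + blk, None])      # re-materialized softmax probs
            dv[j:j + blk] += p.T @ do[i:i + blk]      # crosses outer iters if swap_qk_loop
            dp = do[i:i + blk] @ vj.T
            ds = p * (dp - delta[i:i + blk, None]) * scale
            dq[i:i + blk] += ds @ kj                  # crosses outer iters if NOT swap_qk_loop
            dk[j:j + blk] += ds.T @ qi                # crosses outer iters if swap_qk_loop
    return dq, dk, dv

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M, N, d, scale = 128, 256, 32, 0.125
    q, k, v, do = (rng.standard_normal((M, d)), rng.standard_normal((N, d)),
                   rng.standard_normal((N, d)), rng.standard_normal((M, d)))
    s = scale * q @ k.T
    m = s.max(-1, keepdims=True)
    lse = (np.log(np.exp(s - m).sum(-1, keepdims=True)) + m).squeeze(-1)
    o = np.exp(s - lse[:, None]) @ v
    delta = (do * o).sum(-1)
    ref = attn_bwd_blocked(q, k, v, do, lse, delta, scale, swap_qk_loop=False)
    swp = attn_bwd_blocked(q, k, v, do, lse, delta, scale, swap_qk_loop=True)
    assert all(np.allclose(x, y) for x, y in zip(ref, swp))  # same result, different loop order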
1 parent a77026d · commit 26ce12c

37 files changed: +4085 additions, -1384 deletions

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 # magi_attention
 magi_attention/_version.py
-magi_attention/flex_flash_attn*
 magi_attention/csrc/comm/grpcoll/instantiations/
 *.nsys-rep
 *.ncu-rep

exps/attn/merge_csv.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# Copyright (c) 2025-2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd


def merge_multiple_sources_to_dst(
    src_paths: list[str],
    dst_path: str,
    output_path: str,
    extracted_columns: dict[str, list[str]] | None = None,
    order_columns: list[str] | None = None,
    join_key: str = "seqlen",
):
    """
    Merges unique columns from multiple source CSVs into one destination CSV.

    Args:
        src_paths (list[str]): List of paths to source CSV files.
        dst_path (str): Path to the destination (base) CSV file.
        output_path (str): Path where the merged CSV will be saved.
        extracted_columns (dict[str, list[str]] | None): Dictionary specifying which columns to extract from each source.
        order_columns (list[str] | None): List of strings to specify the preferred column order.
        join_key (str): The common column name used to align rows (default is 'seqlen').
    """
    # 1. Load the base destination CSV
    df_final = pd.read_csv(dst_path)
    print(f"Loaded base destination: {dst_path}")

    # 2. Iterate through each source and merge new columns
    for src_p in src_paths:
        df_src = pd.read_csv(src_p)

        # Extract new columns to add from source
        if extracted_columns is not None and src_p in extracted_columns:
            new_cols = extracted_columns[src_p]
        else:  # If not given, auto-detect new columns
            # Identify columns in current src that are NOT in the current merged result
            new_cols = [col for col in df_src.columns if col not in df_final.columns]

        if new_cols:
            print(f"Adding columns {new_cols} from {src_p}")
            # Subset src to include only join_key and the new unique columns
            src_subset = df_src[[join_key] + new_cols]

            # Left join ensures we keep all rows from the original destination table
            df_final = pd.merge(df_final, src_subset, on=join_key, how="left")
        else:
            print(f"No new columns found in {src_p}, skipping.")

    # 3. Handle column reordering
    if order_columns:
        final_column_sequence = []

        # Force the join_key to be the first column if it's not in the order list
        if join_key not in order_columns:
            final_column_sequence.append(join_key)

        final_column_sequence.extend(order_columns)

        # Filter: Keep only columns that actually exist in the merged dataframe
        existing_ordered_cols = [
            c for c in final_column_sequence if c in df_final.columns
        ]

        # Collect any remaining columns that were not specified in the custom order
        remaining_cols = [c for c in df_final.columns if c not in existing_ordered_cols]

        # Apply the final column order
        df_final = df_final[existing_ordered_cols + remaining_cols]

    # 4. Save the final merged table
    df_final.to_csv(output_path, index=False)
    print(f"\nAll sources merged successfully! Result saved to: {output_path}")


# --- Example Usage ---
if __name__ == "__main__":
    # List of your source files
    source_files = [
        "fa3_ffa.csv",
        "cudnn_fa4.csv",
    ]

    # The base destination file
    dst_path = "sdpa.csv"

    # The output file path
    output_path = "sdpa_fa3_fa4_ffa_cudnn.csv"

    # Preferred order for specific columns
    preferred_order = ["sdpa", "sdpa", "fa3", "fa4", "ffa", "cudnn"]

    merge_multiple_sources_to_dst(
        src_paths=source_files,
        dst_path=dst_path,
        output_path=output_path,
        order_columns=preferred_order,
        join_key="seqlen",
    )
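
For context, here is a minimal dry run of the helper above with made-up benchmark columns and timings; the actual contents of sdpa.csv, fa3_ffa.csv, and the other CSVs are not part of this commit, so the column values below are purely illustrative. It shows the left join on the shared seqlen column and the resulting column order.

# Editor's example, not part of the commit: tiny made-up CSVs merged with the helper above.
import pandas as pd
from merge_csv import merge_multiple_sources_to_dst  # assumes exps/attn is the working directory

# Made-up base table and one made-up source table sharing a "seqlen" column.
pd.DataFrame({"seqlen": [1024, 2048, 4096], "sdpa": [0.11, 0.35, 1.20]}).to_csv("sdpa.csv", index=False)
pd.DataFrame({"seqlen": [1024, 2048, 4096], "ffa": [0.07, 0.21, 0.68]}).to_csv("fa3_ffa.csv", index=False)

merge_multiple_sources_to_dst(
    src_paths=["fa3_ffa.csv"],
    dst_path="sdpa.csv",
    output_path="merged.csv",
    order_columns=["sdpa", "ffa"],
    join_key="seqlen",
)
# merged.csv now holds columns seqlen, sdpa, ffa, with rows aligned on seqlen;
# any seqlen missing from a source would show up as NaN because of the left join.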
