Skip to content

Commit 36ca523

Browse files
authored
[CK_TILE] Update gfx11 FMHA forward kernel configs (#5088)
## Motivation Tune gfx11 FMHA codegen to recover performance, mainly for PSSK (padded seqlen_q/k) cases. This tuning is based on heuristic search and improves performance in most tested shapes. Performance should be evaluated on top of [`#5018`](#5018) (required baseline). ## Technical Details - Updated gfx11 codegen heuristic choices for tile size and occupancy. - Updated gfx11 pipeline selection: - Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad cases are dispatched to the faster kernel path. - Kept gfx12 unchanged: with PSSK support from [`#4957`](#4957), existing gfx12 config is already sufficient. - Tuning rationale: - In some cases, higher `kBlockPerCu` lowers register pressure. - On RDNA, this generally aligns with better performance when `waves_per_eu >= 6`. ## Test Plan - test_ck_tile_fmha - tile_example_fmha_fwd: tested this on gfx1100 and gfx1151 ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24 -d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1} ## Test Result - TFLOPs by sequence length target: `gfx1100` layout: `bhsd` - mode: batch / VGPR usage: 225 vs 214 SeqLen | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 74.10 | 71.97 | 0.97x 4096 | 66.26 | 77.79 | 1.17x 8192 | 68.18 | 75.88 | 1.11x 12288 | 68.47 | 80.44 | 1.17x 16384 | 59.54 | 79.66 | 1.34x 20480 | 55.78 | 77.91 | 1.40x 24576 | 55.08 | 77.47 | 1.41x 27280 | 47.45 | 77.16 | 1.63x - mode: group / VGPR usage: 256 vs 214 SeqLen | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 71.47 | 70.6 | 0.99x 4096 | 64.74 | 77.06 | 1.19x 8192 | 64.68 | 75.47 | 1.17x 12288 | 66.43 | 79.95 | 1.20x 16384 | 56.02 | 79.73 | 1.42x 20480 | 50.21 | 78.15 | 1.56x 24576 | 47.29 | 77.53 | 1.64x 27280 | 46.13 | 77.04 | 1.67x - TFLOPs by sequence length target: `gfx1151` layout: `bshd` - mode: batch / VGPR usage: 225 vs 223 SeqLen | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 26.85 | 29.17 | 
1.09x 4096 | 24.75 | 26.01 | 1.05x 8192 | 25.24 | 25.50 | 1.01x 12288 | 25.18 | 25.00 | 0.99x 16384 | 24.79 | 25.91 | 1.05x 20480 | 25.56 | 25.24 | 0.99x 24576 | 25.13 | 26.20 | 1.04x 27280 | 10.78 | 26.35 | 2.44x - mode: group / VGPR usage: 256 vs 229 SeqLen | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 27.44 | 26.71 | 0.97x 4096 | 21.89 | 23.09 | 1.05x 8192 | 22.85 | 24.49 | 1.07x 12288 | 24.33 | 24.42 | 1.00x 16384 | 20.05 | 24.98 | 1.24x 20480 | 14.70 | 25.15 | 1.71x 24576 | 11.30 | 26.31 | 2.33x 27280 | 10.10 | 26.32 | 2.61x ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent feda326 commit 36ca523

File tree

4 files changed

+34
-5
lines changed

4 files changed

+34
-5
lines changed

projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,7 +1095,10 @@ def get_pipelines(
10951095

10961096

10971097
class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
1098-
arch = ArchTrait("gfx11")
1098+
arch = ArchTrait(
1099+
"gfx11",
1100+
preprocessor_check="defined(__gfx11__) && !defined(__gfx115__)",
1101+
)
10991102

11001103
_DT_FP16_BF16 = ("fp16", "bf16")
11011104

@@ -1109,10 +1112,12 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
11091112
return {
11101113
# bm0, bn0, bk0, bn1, bk1,
11111114
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1112-
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1113-
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1115+
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
1116+
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1117+
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
1118+
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
11141119
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1115-
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1120+
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
11161121
} # fmt: skip
11171122
else:
11181123
raise ValueError(f"unsupported dtype={dtype}")
@@ -1133,12 +1138,25 @@ def get_pipelines(
11331138
["t", "f"],
11341139
["t", "f"],
11351140
):
1136-
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1141+
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
1142+
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
11371143
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
11381144
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
11391145
return pipelines
11401146

11411147

1148+
class KernelComponentFactoryGfx115(KernelComponentFactoryGfx11):
1149+
arch = ArchTrait("gfx115")
1150+
1151+
@classmethod
1152+
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
1153+
result = super().get_hdim_tile_size_dict(dtype)
1154+
if dtype in cls._DT_FP16_BF16:
1155+
result[(64, 64)] = [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip
1156+
result[(256, 256)] = [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip
1157+
return result
1158+
1159+
11421160
class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
11431161
arch = ArchTrait("gfx12")
11441162

@@ -1230,6 +1248,8 @@ def get_factory(target: str):
12301248
if target.startswith("gfx9"):
12311249
return KernelComponentFactoryGfx9
12321250

1251+
if target.startswith("gfx115"):
1252+
return KernelComponentFactoryGfx115
12331253
if target.startswith("gfx11"):
12341254
return KernelComponentFactoryGfx11
12351255
if target.startswith("gfx12"):

projects/composablekernel/include/ck_tile/core.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "ck_tile/core/arch/mma/mma_selector.hpp"
2424
#include "ck_tile/core/arch/mma/mma_traits.hpp"
2525
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
26+
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
2627
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
2728
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
2829
#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp"

projects/composablekernel/include/ck_tile/core/arch/arch.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,9 @@ struct gfx103_t
11411141
struct gfx11_t
11421142
{
11431143
};
1144+
struct gfx115_t
1145+
{
1146+
};
11441147
struct gfx12_t
11451148
{
11461149
};
@@ -1174,6 +1177,8 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }
11741177

11751178
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }
11761179

1180+
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx115_t) { return 32; }
1181+
11771182
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }
11781183

11791184
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }

projects/composablekernel/include/ck_tile/core/config.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__)
2525
#define __gfx11__
2626
#endif
27+
#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
28+
#define __gfx115__
29+
#endif
2730
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
2831
#define __gfx12__
2932
#endif

0 commit comments

Comments
 (0)