Skip to content

Commit 494d73f

Browse files
authored
FlashMask v3 support headdim > 128 (PaddlePaddle#76365)
* fm3 support head dim in range (128, 256] * update fa submodule * fix codestyle
1 parent 3218241 commit 494d73f

File tree

4 files changed

+11
-16
lines changed

4 files changed

+11
-16
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,20 +1111,15 @@ void FlashMaskV2GradBaseKernel(
11111111
return {64, 64};
11121112
}
11131113
}
1114-
} else if (head_size_rounded <= 192) {
1115-
// umiswing: head dim > 128 is not supported now
1116-
PADDLE_THROW(
1117-
common::errors::Unimplemented("head dim is rounded to %d, which is "
1118-
"not supported in FlashMask V3 now.",
1119-
head_size_rounded));
1120-
return {0, 0};
11211114
} else if (head_size_rounded <= 256) {
1122-
// umiswing: head dim > 128 is not supported now
1123-
PADDLE_THROW(
1124-
common::errors::Unimplemented("head dim is rounded to %d, which is "
1125-
"not supported in FlashMask V3 now.",
1126-
head_size_rounded));
1127-
return {0, 0};
1115+
// umiswing: by now, we reuse template instantiation of head dim 256 for
1116+
// head dim in range (128, 256], and therefore no separate dispatch for
1117+
// head dim in range (128, 192]
1118+
if (has_lt_end && has_ut_start) {
1119+
return {64, 32};
1120+
} else {
1121+
return {64, 64};
1122+
}
11281123
} else {
11291124
PADDLE_THROW(
11301125
common::errors::Unimplemented("head dim is rounded to %d, which is "

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1396,7 +1396,7 @@ void FlashMaskV2BaseKernel(
13961396
common::errors::InvalidArgument(
13971397
"batch_size must be equal to batch_size_k"));
13981398
}
1399-
int const max_headdim = std::min(flashmaskv2_get_max_headdim(), 128);
1399+
int const max_headdim = flashmaskv2_get_max_headdim();
14001400
PADDLE_ENFORCE_LE(
14011401
head_size,
14021402
max_headdim,

paddle/phi/kernels/gpu/flash_attn_v3_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ inline int get_max_headdim() {
6868
return 0;
6969
}
7070

71-
inline int flashmaskv2_get_max_headdim() { return 128; }
71+
inline int flashmaskv2_get_max_headdim() { return 256; }
7272

7373
inline int round_up_headdim(int head_size) {
7474
#ifndef FLASHATTENTION_DISABLE_HDIM64

0 commit comments

Comments
 (0)