Commit f2b4b6d

[Benchmarks] Add more variants in XeTLA FA implementation (#2309)
More shapes and causal support.
1 parent 3f81c35 commit f2b4b6d

3 files changed: 28 additions, 72 deletions


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 10 additions & 3 deletions
@@ -188,13 +188,19 @@ def forward(q, k, v, causal, sm_scale):
         # argument names to use as an x-axis for the plot
         x_names=['Z', 'H', 'N_CTX', 'D_HEAD'],
         x_vals=[  #
+            [1, 16, 16384, 128],  #
             [1, 32, 16384, 64],  #
+            [2, 16, 8192, 128],  #
             [2, 32, 8192, 64],  #
+            [4, 16, 4096, 128],  #
             [4, 32, 4096, 64],  #
             [4, 48, 1024, 64],  #
+            [8, 16, 2048, 128],  #
             [8, 32, 2048, 64],  #
+            [16, 16, 1024, 128],  #
             [16, 32, 1024, 64],  #
-            [32, 32, 512, 64]  #
+            [32, 16, 512, 128],  #
+            [32, 32, 512, 64],  #
         ],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
@@ -238,7 +244,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
                                                               fast_flush=False)

     elif provider == 'xetla':
-        func = getattr(xetla_kernel, 'flash_attn')
+        module_name = f'flash_attn_causal_{causal}'.lower()
+        func = getattr(xetla_kernel, module_name)
         out = torch.empty_like(q, device='xpu', dtype=dtype)
         size_score = Z * H * N_CTX * N_CTX
         size_attn_mask = Z * N_CTX * N_CTX
@@ -248,7 +255,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
         m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

-        xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX)
+        xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
                                                               fast_flush=False)
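
For reference, a minimal sketch of what the new dispatch in the xetla branch resolves to. The shape row and the 1/sqrt(D_HEAD) choice of sm_scale are illustrative assumptions; the compiled xetla_kernel extension is not needed to follow the name resolution:

import math

# One of the newly added benchmark shapes (Z, H, N_CTX, D_HEAD); illustrative pick.
Z, H, N_CTX, D_HEAD = 1, 16, 16384, 128
causal = True

# A Python bool formats as 'True'/'False'; lower() matches the binding names
# registered in python_main.cpp (flash_attn_causal_true / flash_attn_causal_false).
module_name = f'flash_attn_causal_{causal}'.lower()
assert module_name == 'flash_attn_causal_true'

# sm_scale is now an explicit argument; the header previously derived it as
# rsqrt(head_size), so 1/sqrt(D_HEAD) is the natural value to pass (assumption).
sm_scale = 1.0 / math.sqrt(D_HEAD)
print(module_name, sm_scale)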

benchmarks/xetla_kernel/flash_attention/fmha_forward_v5.h

Lines changed: 6 additions & 61 deletions
@@ -29,37 +29,6 @@ namespace gpu::xetla {

 namespace fmha {

-struct Shape {
-  Shape(int B, int N, int F, int T, int H)
-      : num_batches(B), num_heads(N), num_queries(F), num_keys(T),
-        head_size(H) {}
-  const int num_batches;
-  const int num_heads;
-  const int num_queries;
-  const int num_keys;
-  const int head_size;
-
-  inline uint32_t get_query_size() const {
-    return num_batches * num_heads * num_queries * head_size;
-  }
-  inline uint32_t get_key_size() const {
-    return num_batches * num_heads * num_keys * head_size;
-  }
-  inline uint32_t get_score_size() const {
-    return num_batches * num_heads * num_queries * num_keys;
-  }
-  inline uint32_t get_ml_size() const {
-    return num_batches * num_heads * num_queries;
-  }
-  inline uint32_t get_attn_mask_size() const {
-#if _BIAS_AS_INPUT
-    return num_batches * num_heads * num_queries * num_keys;
-#else
-    return num_batches * num_queries * num_keys;
-#endif
-  }
-};
-
 template <typename fmha_policy, typename scalar_t, bool kUseBias,
           bool kIsCausal, bool kIsTraining>
 class fmha_forward_t {
@@ -620,46 +589,28 @@ class FmhaForwardKernel;
 // The launcher of fmha forward kernel
 template <typename fmha_policy, typename T, bool kUseBias = false,
           bool kIsCausal = false, bool kIsTraining = false>
-sycl::event fmha_forward_impl(sycl::queue &q, void *_q, void *_k, void *_v,
-                              void *_out, void *_dropout_mask, void *_bias,
-                              void *_m, void *_l, uint32_t num_batches,
-                              uint32_t num_heads, uint32_t head_size,
-                              uint32_t num_queries, uint32_t num_keys,
-                              uint64_t seed = 0, uint64_t offset = 123) {
-
-  Shape shape(num_batches, num_heads, num_queries, num_keys, head_size);
+sycl::event
+fmha_forward_impl(sycl::queue &q, void *_q, void *_k, void *_v, void *_out,
+                  void *_dropout_mask, void *_bias, void *_m, void *_l,
+                  uint32_t num_batches, uint32_t num_heads, uint32_t head_size,
+                  uint32_t num_queries, uint32_t num_keys, float head_scale,
+                  uint64_t seed = 0, uint64_t offset = 123) {

   constexpr bool use_mask = false;
   constexpr bool use_dropout = false;
   float dropout_prob = 0.0f;
   if constexpr (use_dropout)
     dropout_prob = 0.5f;
-  const float scale = 1 / (1 - dropout_prob);
-  const float head_scale = sycl::rsqrt(float(head_size));
-
-  uint32_t size_query = shape.get_query_size();
-  uint32_t size_key = shape.get_key_size();
-  uint32_t size_score = shape.get_score_size();
-  uint32_t size_attn_mask = shape.get_attn_mask_size();
-  uint32_t size_ml = shape.get_ml_size();

   // forward
-  // T *query = sycl::malloc_shared<T>(size_query, q);
-  // T *key = sycl::malloc_shared<T>(size_key, q);
-  // T *value = sycl::malloc_shared<T>(size_key, q);
   T *query = static_cast<T *>(_q);
   T *key = static_cast<T *>(_k);
   T *value = static_cast<T *>(_v);

-  // T *bias = sycl::malloc_shared<T>(size_attn_mask, q);
   T *bias = static_cast<T *>(_bias);
-  // uint8_t *dropout_mask = sycl::malloc_shared<uint8_t>(size_score, q);
   uint8_t *dropout_mask = static_cast<uint8_t *>(_dropout_mask);
-  // T *out = sycl::malloc_shared<T>(size_query, q);
   T *out = static_cast<T *>(_out);
-  // float *m = sycl::malloc_shared<float>(size_ml, q);
   float *m = static_cast<float *>(_m);
-  // float *l = sycl::malloc_shared<float>(size_ml, q);
   float *l = static_cast<float *>(_l);

   // fmha forward kernel
@@ -687,12 +638,6 @@ sycl::event fmha_forward_impl(sycl::queue &q, void *_q, void *_k, void *_v,
       fmha_fwd_op(ei, args);
     });
   });
-  // sycl::free(query, q);
-  // sycl::free(key, q);
-  // sycl::free(value, q);
-  // sycl::free(bias, q);
-  // sycl::free(dropout_mask, q);
-  // sycl::free(out, q);
   return event;
 }
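
With the Shape helper removed, the launcher no longer derives buffer sizes itself; the caller (here, the Python benchmark) sizes out, dropout_mask, bias, m and l up front. A minimal Python sketch of the equivalent bookkeeping, with the element-count formulas copied from the deleted get_*_size() methods (the function name and dict layout are illustrative, and bias_as_input mirrors the _BIAS_AS_INPUT branch):

def fmha_buffer_sizes(num_batches, num_heads, num_queries, num_keys, head_size,
                      bias_as_input=False):
    # Element counts matching the removed Shape::get_*_size() helpers.
    return {
        'query': num_batches * num_heads * num_queries * head_size,   # also out
        'key': num_batches * num_heads * num_keys * head_size,        # also value
        'score': num_batches * num_heads * num_queries * num_keys,    # dropout_mask
        'ml': num_batches * num_heads * num_queries,                  # m and l
        'attn_mask': (num_batches * num_heads * num_queries * num_keys
                      if bias_as_input
                      else num_batches * num_queries * num_keys),     # bias
    }

# Example with one of the new benchmark shapes: Z=1, H=16, N_CTX=16384, D_HEAD=128.
print(fmha_buffer_sizes(1, 16, 16384, 16384, 128))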

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 12 additions & 8 deletions
@@ -94,11 +94,11 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
   return acc;
 }

-#define CALL_IMPL_ATTENTION_FUNC(P) \
+#define CALL_IMPL_ATTENTION_FWD_FUNC(P) \
   fmha::fmha_forward_impl<P, T, use_mask, IsCausal, use_dropout>( \
       queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), \
       dropout_mask.data_ptr(), bias.data_ptr(), m.data_ptr(), l.data_ptr(), \
-      num_batches, num_heads, head_size, num_queries, num_keys)
+      num_batches, num_heads, head_size, num_queries, num_keys, head_scale)

 template <bool use_mask = false, bool IsCausal = false,
           bool use_dropout = false>
@@ -107,7 +107,8 @@ void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,
                 const at::Tensor &bias, const at::Tensor &m,
                 const at::Tensor &l, const int64_t num_batches,
                 const int64_t num_heads, const int64_t head_size,
-                const int64_t num_queries, const int64_t num_keys) {
+                const int64_t num_queries, const int64_t num_keys,
+                float head_scale) {

   CHECK_INPUT(q);
   CHECK_INPUT(k);
@@ -126,14 +127,14 @@ void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,

   sycl::event evt;
   if (head_size <= 64) {
-    evt = CALL_IMPL_ATTENTION_FUNC(fmha_policy_64x128x64);
+    evt = CALL_IMPL_ATTENTION_FWD_FUNC(fmha_policy_64x128x64);
   } else if (head_size <= 128) {
-    evt = CALL_IMPL_ATTENTION_FUNC(fmha_policy_64x128x128);
+    evt = CALL_IMPL_ATTENTION_FWD_FUNC(fmha_policy_64x128x128);
   } else if (head_size <= 25) {
     if (num_keys <= 256) {
-      evt = CALL_IMPL_ATTENTION_FUNC(fmha_policy_32x256x256);
+      evt = CALL_IMPL_ATTENTION_FWD_FUNC(fmha_policy_32x256x256);
     } else {
-      evt = CALL_IMPL_ATTENTION_FUNC(fmha_policy_64x512x256);
+      evt = CALL_IMPL_ATTENTION_FWD_FUNC(fmha_policy_64x512x256);
     }
   } else {
     std::cout << "No policy available for current head_size " << head_size
@@ -213,5 +214,8 @@ PYBIND11_MODULE(xetla_kernel, m) {
   m.def("gemm_shape_4096_8_16384_128",
         &bf16_gemm<Test_4096x8x16384x128_row_row>, "bf16_gemm (XeTLA)");
   // flash_attn
-  m.def("flash_attn", &flash_attn<false, false, false>, "flash attn (XeTLA)");
+  m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
+        "flash attn fwd (XeTLA)");
+  m.def("flash_attn_causal_true", &flash_attn<false, true, false>,
+        "flash attn fwd (XeTLA)");
 }
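
Taken together, a hedged end-to-end sketch of how a caller might use the two new bindings. It mirrors the benchmark's xetla branch; the xetla_kernel extension and an XPU device are required at runtime, so the import is guarded, and the q layout plus the dropout_mask/bias dtypes are assumptions based on the casts in fmha_forward_impl:

import torch

def run_xetla_flash_attn(q, k, v, causal, sm_scale):
    """Illustrative wrapper around the new flash_attn_causal_{true,false} bindings."""
    try:
        import xetla_kernel  # pybind11 module built from python_main.cpp
    except ImportError:
        return None  # extension not built; this is a sketch, not the benchmark itself
    Z, H, N_CTX, D_HEAD = q.shape  # assumed [batch, heads, seq_len, head_dim] layout
    out = torch.empty_like(q, device='xpu', dtype=q.dtype)
    size_score = Z * H * N_CTX * N_CTX
    size_attn_mask = Z * N_CTX * N_CTX
    size_ml = Z * H * N_CTX
    dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)  # uint8_t* in the kernel
    bias = torch.empty((size_attn_mask, ), device='xpu', dtype=q.dtype)          # T* in the kernel
    m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
    l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
    func = getattr(xetla_kernel, f'flash_attn_causal_{causal}'.lower())
    func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
    return out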
