@@ -29,37 +29,6 @@ namespace gpu::xetla {
2929
3030namespace fmha {
3131
32- struct Shape {
33- Shape (int B, int N, int F, int T, int H)
34- : num_batches(B), num_heads(N), num_queries(F), num_keys(T),
35- head_size (H) {}
36- const int num_batches;
37- const int num_heads;
38- const int num_queries;
39- const int num_keys;
40- const int head_size;
41-
42- inline uint32_t get_query_size () const {
43- return num_batches * num_heads * num_queries * head_size;
44- }
45- inline uint32_t get_key_size () const {
46- return num_batches * num_heads * num_keys * head_size;
47- }
48- inline uint32_t get_score_size () const {
49- return num_batches * num_heads * num_queries * num_keys;
50- }
51- inline uint32_t get_ml_size () const {
52- return num_batches * num_heads * num_queries;
53- }
54- inline uint32_t get_attn_mask_size () const {
55- #if _BIAS_AS_INPUT
56- return num_batches * num_heads * num_queries * num_keys;
57- #else
58- return num_batches * num_queries * num_keys;
59- #endif
60- }
61- };
62-
6332template <typename fmha_policy, typename scalar_t , bool kUseBias ,
6433 bool kIsCausal , bool kIsTraining >
6534class fmha_forward_t {
@@ -620,46 +589,28 @@ class FmhaForwardKernel;
620589// The launcher of fmha forward kernel
621590template <typename fmha_policy, typename T, bool kUseBias = false ,
622591 bool kIsCausal = false , bool kIsTraining = false >
623- sycl::event fmha_forward_impl (sycl::queue &q, void *_q, void *_k, void *_v,
624- void *_out, void *_dropout_mask, void *_bias,
625- void *_m, void *_l, uint32_t num_batches,
626- uint32_t num_heads, uint32_t head_size,
627- uint32_t num_queries, uint32_t num_keys,
628- uint64_t seed = 0 , uint64_t offset = 123 ) {
629-
630- Shape shape (num_batches, num_heads, num_queries, num_keys, head_size);
592+ sycl::event
593+ fmha_forward_impl (sycl::queue &q, void *_q, void *_k, void *_v, void *_out,
594+ void *_dropout_mask, void *_bias, void *_m, void *_l,
595+ uint32_t num_batches, uint32_t num_heads, uint32_t head_size,
596+ uint32_t num_queries, uint32_t num_keys, float head_scale,
597+ uint64_t seed = 0 , uint64_t offset = 123 ) {
631598
632599 constexpr bool use_mask = false ;
633600 constexpr bool use_dropout = false ;
634601 float dropout_prob = 0 .0f ;
635602 if constexpr (use_dropout)
636603 dropout_prob = 0 .5f ;
637- const float scale = 1 / (1 - dropout_prob);
638- const float head_scale = sycl::rsqrt (float (head_size));
639-
640- uint32_t size_query = shape.get_query_size ();
641- uint32_t size_key = shape.get_key_size ();
642- uint32_t size_score = shape.get_score_size ();
643- uint32_t size_attn_mask = shape.get_attn_mask_size ();
644- uint32_t size_ml = shape.get_ml_size ();
645604
646605 // forward
647- // T *query = sycl::malloc_shared<T>(size_query, q);
648- // T *key = sycl::malloc_shared<T>(size_key, q);
649- // T *value = sycl::malloc_shared<T>(size_key, q);
650606 T *query = static_cast <T *>(_q);
651607 T *key = static_cast <T *>(_k);
652608 T *value = static_cast <T *>(_v);
653609
654- // T *bias = sycl::malloc_shared<T>(size_attn_mask, q);
655610 T *bias = static_cast <T *>(_bias);
656- // uint8_t *dropout_mask = sycl::malloc_shared<uint8_t>(size_score, q);
657611 uint8_t *dropout_mask = static_cast <uint8_t *>(_dropout_mask);
658- // T *out = sycl::malloc_shared<T>(size_query, q);
659612 T *out = static_cast <T *>(_out);
660- // float *m = sycl::malloc_shared<float>(size_ml, q);
661613 float *m = static_cast <float *>(_m);
662- // float *l = sycl::malloc_shared<float>(size_ml, q);
663614 float *l = static_cast <float *>(_l);
664615
665616 // fmha forward kernel
@@ -687,12 +638,6 @@ sycl::event fmha_forward_impl(sycl::queue &q, void *_q, void *_k, void *_v,
687638 fmha_fwd_op (ei, args);
688639 });
689640 });
690- // sycl::free(query, q);
691- // sycl::free(key, q);
692- // sycl::free(value, q);
693- // sycl::free(bias, q);
694- // sycl::free(dropout_mask, q);
695- // sycl::free(out, q);
696641 return event;
697642}
698643
0 commit comments