1+ // Copyright (C) 2024 Zhiyuan Li
2+
3+
#include <sycl/sycl.hpp>
#include "wkv6.hpp"

constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
8+
// Fused RWKV v6 time-mix (WKV) kernel.
//
// One work-group handles one (batch, head) pair: bid = batch_i * H + head_i.
// Each of the head_size work-items owns output channel `tid` of that head and
// carries one head_size-long column of the per-head state matrix in private
// memory across the whole token sequence.
//
//   B, T, C, H  -- batch count, total tokens (all sequences), channels, heads
//   k, v, r     -- key / value / receptance activations; indexed with stride C
//                  per token, i.e. laid out as [tokens, C]
//   tf          -- time-first parameter, one head_size vector per head
//                  (constant over tokens)
//   td          -- time-decay, per token, same [tokens, C] layout as k/v/r
//   s           -- initial state, one head_size x head_size matrix per
//                  (batch, head)
//   dst         -- outputs [tokens, C]; the final state is appended starting
//                  at offset T * C (see the tail loop below)
//   shared_mem  -- work-group local scratch for 4 * head_size floats
//                  (_k, _r, _tf, _td)
static void rwkv_wkv_f32_kernel(
    const int B, const int T, const int C, const int H,
    const float* k, const float* v, const float* r,
    const float* tf, const float* td, const float* s,
    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {

    const int tid = item_ct1.get_local_id(2);  // channel within this head
    const int bid = item_ct1.get_group(2);     // flattened (batch, head) index

    // Head size is fixed at compile time; the launcher asserts C / H == WKV_BLOCK_SIZE.
    const int head_size = WKV_BLOCK_SIZE;
    const int batch_i = bid / H;
    const int head_i = bid % H;
    const int state_size = C * head_size;      // floats of state per batch
    const int n_seq_tokens = T / B;            // tokens per sequence

    // Carve the local scratch into four head_size-float slices.
    float* _k = shared_mem;
    float* _r = _k + head_size;
    float* _tf = _r + head_size;
    float* _td = _tf + head_size;

    // Private copy of this thread's state column (one float per row i).
    float state[WKV_BLOCK_SIZE];

    // Load initial state: column `tid` of this (batch, head)'s
    // head_size x head_size state matrix.
#pragma unroll
    for (int i = 0; i < head_size; i++) {
        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
    }

    // Sync threads before shared memory operations
    item_ct1.barrier(sycl::access::fence_space::local_space);

    // tf is constant over the whole sequence, so stage it in local memory once.
    _tf[tid] = tf[head_i * head_size + tid];
    item_ct1.barrier(sycl::access::fence_space::local_space);

    // Walk this sequence token by token; t is the flat index of element
    // (token, head_i * head_size + tid) and advances by C per token.
    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
         t += C) {

        // Ensure no thread is still reading the previous token's _k/_r/_td
        // before they are overwritten below.
        item_ct1.barrier(sycl::access::fence_space::local_space);

        // Stage the current token's k, r and decay in local memory so every
        // thread can read all head_size channels of them.
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];

        item_ct1.barrier(sycl::access::fence_space::local_space);

        const float _v = v[t];
        float y = 0;

        // Process in chunks of 4 for better vectorization; head_size (64) is
        // a multiple of 4, so no scalar tail is needed.
#pragma unroll
        for (int j = 0; j < head_size; j += 4) {
            // Load data in vec4 chunks
            sycl::float4 k4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            sycl::float4 r4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            sycl::float4 tf4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            sycl::float4 td4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            sycl::float4 s4(state[j], state[j+1], state[j+2], state[j+3]);

            // Compute key-value product
            sycl::float4 kv4 = k4 * _v;

            // y += r . (tf * kv + state)  -- the WKV weighted sum
            y += sycl::dot(r4, tf4 * kv4 + s4);

            // state = state * td + kv  -- decayed state update
            s4 = s4 * td4 + kv4;

            // Store updated state
            state[j] = s4.x();
            state[j+1] = s4.y();
            state[j+2] = s4.z();
            state[j+3] = s4.w();
        }

        dst[t] = y;
    }

    // Append the final state after the T*C outputs, mirroring the initial
    // state layout so it can seed the next chunk of the sequence.
#pragma unroll
    for (int i = 0; i < head_size; i++) {
        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
    }
}
99+
100+ void ggml_sycl_op_rwkv_wkv6 (ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
101+ const ggml_tensor* src1, ggml_tensor* dst) {
102+
103+ const float * k_d = (const float *)dst->src [0 ]->data ;
104+ const float * v_d = (const float *)dst->src [1 ]->data ;
105+ const float * r_d = (const float *)dst->src [2 ]->data ;
106+ const float * tf_d = (const float *)dst->src [3 ]->data ;
107+ const float * td_d = (const float *)dst->src [4 ]->data ;
108+ const float * s_d = (const float *)dst->src [5 ]->data ;
109+ float * dst_d = (float *)dst->data ;
110+
111+ const int64_t B = dst->src [5 ]->ne [1 ];
112+ const int64_t T = dst->src [0 ]->ne [3 ];
113+ const int64_t C = dst->ne [0 ];
114+ const int64_t H = dst->src [0 ]->ne [2 ];
115+
116+ GGML_ASSERT (dst->src [5 ]->type == GGML_TYPE_F32);
117+ GGML_ASSERT (C % H == 0 );
118+ GGML_ASSERT (C / H == WKV_BLOCK_SIZE);
119+
120+ dpct::queue_ptr stream = ctx.stream ();
121+
122+ // Calculate execution configuration
123+ const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof (float ); // For k, r, tf, td
124+ sycl::range<3 > block_dims (1 , 1 , C / H);
125+ sycl::range<3 > grid_dims (1 , 1 , B * H);
126+
127+ // Submit kernel
128+ stream->submit ([&](sycl::handler& cgh) {
129+ sycl::local_accessor<float , 1 > shared_mem_acc (shared_mem_size, cgh);
130+
131+ cgh.parallel_for (
132+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
133+ [=](sycl::nd_item<3 > item_ct1) {
134+ rwkv_wkv_f32_kernel (
135+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
136+ item_ct1, shared_mem_acc.get_pointer ()
137+ );
138+ });
139+ });
140+ }