We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6a1e977 commit a749ba7Copy full SHA for a749ba7
ggml/src/ggml-sycl/wkv6.cpp
@@ -59,14 +59,15 @@ static void rwkv_wkv_f32_kernel(
59
float y = 0;
60
61
// Process in chunks of 4 for better vectorization
62
+ sycl::float4 k4, r4, tf4, td4, s4, kv4;
63
#pragma unroll
64
for (int j = 0; j < head_size; j += 4) {
65
// Load data in vec4 chunks
- sycl::float4 k4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
66
- sycl::float4 r4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
67
- sycl::float4 tf4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
68
- sycl::float4 td4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
69
- sycl::float4 s4(state[j], state[j+1], state[j+2], state[j+3]);
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+ td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
70
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
71
72
// Compute key-value product
73
sycl::float4 kv4 = k4 * _v;
0 commit comments