#include "roll.hpp"
#include "common.hpp"

#include <cstdio>  // std::fprintf / stderr (may also be pulled in via common.hpp)

using namespace sycl;

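// Wrap i + shift back into [0, n). Both i and shift are expected to already be
// in [0, n), so one conditional subtraction replaces a modulo in the kernel.
// Illustrative example: wrap_add(2, 3, 4) == 1.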
static inline int wrap_add(int i, int shift, int n) {
    int s = i + shift;
    return (s >= n) ? (s - n) : s;
}

static void kernel_roll_fused_i0_i1(
    queue &q,
    const float *src_d,
    float *dst_d,
    int ne0, int ne1, int ne2, int ne3,
    int sh0, int sh1, int sh2, int sh3)
{
    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;

    // Strides of a contiguous [ne0, ne1, ne2, ne3] tensor.
    const int stride1 = ne0;
    const int stride2 = ne0 * ne1;
    const int stride3 = ne0 * ne1 * ne2;

    // Turn each forward shift into the source offset to read from:
    // dst[..., i, ...] = src[..., (i + (ne - sh)) % ne, ...].
    const int shNe0 = (ne0 - sh0) % ne0;
    const int shNe1 = (ne1 - sh1) % ne1;
    const int shNe2 = (ne2 - sh2) % ne2;
    const int shNe3 = (ne3 - sh3) % ne3;
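    // Illustrative example: with ne0 = 4 and sh0 = 1, shNe0 = 3, so
    // dst[0] = src[3], dst[1] = src[0], dst[2] = src[1], dst[3] = src[2],
    // i.e. the innermost dimension rolled forward by one.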

    // Map the 4D index space onto SYCL's 3D range: ne3 and ne2 get their own
    // dimensions, while i1 and i0 are fused into the last (fastest) one.
    const size_t g0 = (size_t) ne3;
    const size_t g1 = (size_t) ne2;
    const size_t g2 = (size_t) (ne1 * ne0);

    const range<3> global{ g0, g1, g2 };

    q.submit([&](handler &h) {
        h.parallel_for(global, [=](id<3> idx) {
            const int i3 = (int) idx[0];
            const int i2 = (int) idx[1];

            // Un-fuse the last range dimension back into (i1, i0).
            const int fused = (int) idx[2];
            const int i1 = fused / ne0;
            const int i0 = fused - i1 * ne0;  // fused % ne0

            // Linear offset of the destination element (contiguous layout).
            const int idx_dst = i0
                              + i1 * stride1
                              + i2 * stride2
                              + i3 * stride3;

            // Source coordinates after the wrapped per-dimension shift.
            const int s0 = wrap_add(i0, shNe0, ne0);
            const int s1 = wrap_add(i1, shNe1, ne1);
            const int s2 = wrap_add(i2, shNe2, ne2);
            const int s3 = wrap_add(i3, shNe3, ne3);

            const int idx_src = s0
                              + s1 * stride1
                              + s2 * stride2
                              + s3 * stride3;

            dst_d[idx_dst] = src_d[idx_src];
        });
    });
}
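
// Note: the submit above is asynchronous and is not waited on here; as with the
// other element-wise SYCL ops, ordering is assumed to be handled by the
// backend's in-order stream.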

void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const ggml_tensor *src = dst->src[0];
    GGML_ASSERT(src && src->type == GGML_TYPE_F32);

    const int ne0 = (int) dst->ne[0];
    const int ne1 = (int) dst->ne[1];
    const int ne2 = (int) dst->ne[2];
    const int ne3 = (int) dst->ne[3];

    const int32_t *params = (const int32_t *) dst->op_params;
    int shift0 = params[0];
    int shift1 = params[1];
    int shift2 = params[2];
    int shift3 = params[3];

    // Fast path: all shifts are zero, so the roll degenerates to a copy.
    if ((shift0 | shift1 | shift2 | shift3) == 0) {
        const size_t nb = ggml_nbytes(src);
        queue *q = ctx.stream();
        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
        return;
    }

    // Normalize each shift into [0, ne): handles negative shifts and shifts
    // larger than the dimension.
    auto norm = [](int sh, int n) -> int {
        if (n <= 0) return 0;
        sh %= n;
        if (sh < 0) sh += n;
        return sh;
    };
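    // Illustrative example: norm(-1, 8) == 7 and norm(10, 8) == 2.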
    shift0 = norm(shift0, ne0);
    shift1 = norm(shift1, ne1);
    shift2 = norm(shift2, ne2);
    shift3 = norm(shift3, ne3);

    try {
        queue *q = ctx.stream();

        const float *src_d = (const float *) src->data;
        float *dst_d = (float *) dst->data;
        GGML_ASSERT(src_d && dst_d);

        kernel_roll_fused_i0_i1(
            *q, src_d, dst_d,
            ne0, ne1, ne2, ne3,
            shift0, shift1, shift2, shift3
        );
    } catch (const std::exception &e) {
        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
        throw;
    }
}