| 
2 | 2 | #include "ggml-sycl/presets.hpp"  | 
3 | 3 | #include "ggml.h"  | 
4 | 4 | #include "element_wise.hpp"  | 
 | 5 | +#include <cstring>    | 
5 | 6 | 
 
  | 
6 | 7 | #define SYCL_GLOBAL_ID_LOOP(K, ITEM) \  | 
7 | 8 |     for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))  | 
@@ -926,6 +927,135 @@ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor  | 
926 | 927 |             ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);  | 
927 | 928 |         });  | 
928 | 929 | }  | 
 | 930 | +static inline void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {  | 
 | 931 | +    const ggml_tensor * src0 = dst->src[0];  | 
 | 932 | +    GGML_ASSERT(dst->src[1] != nullptr);  | 
 | 933 | +    const ggml_tensor * src1 = dst->src[1];  | 
 | 934 | + | 
 | 935 | +    GGML_ASSERT(src0->type == dst->type);  | 
 | 936 | +    GGML_ASSERT(src1->type == dst->type);  | 
 | 937 | +#if defined(GGML_SYCL_F16)  | 
 | 938 | +    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_I32);  | 
 | 939 | +#else  | 
 | 940 | +    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32);  | 
 | 941 | +#endif  | 
 | 942 | +    const size_t ts = ggml_type_size(dst->type);   | 
 | 943 | + | 
 | 944 | +    dpct::queue_ptr q = ctx.stream();  | 
 | 945 | +    {  | 
 | 946 | +        const bool same_type = (src0->type == dst->type);  | 
 | 947 | +        const bool src_cont  = ggml_is_contiguous(src0);  | 
 | 948 | +        const bool dst_cont  = ggml_is_contiguous(dst);  | 
 | 949 | + | 
 | 950 | +        const void *p_src0 = src0->data;  | 
 | 951 | +        void       *p_dst  = dst->data;  | 
 | 952 | + | 
 | 953 | +        auto pt_src0 = sycl::get_pointer_type((const char*)p_src0, q->get_context());  | 
 | 954 | +        auto pt_dst  = sycl::get_pointer_type((char*)p_dst,       q->get_context());  | 
 | 955 | + | 
 | 956 | +        if (same_type && src_cont && dst_cont && ggml_nelements(src0) == ggml_nelements(dst)) {  | 
 | 957 | +            const size_t bytes = ggml_nbytes(dst);  | 
 | 958 | +            if (pt_src0 != sycl::usm::alloc::unknown && pt_dst != sycl::usm::alloc::unknown) {  | 
 | 959 | +                SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(p_dst, p_src0, bytes)));  | 
 | 960 | +            } else {  | 
 | 961 | +                std::memcpy(p_dst, p_src0, bytes);  | 
 | 962 | +            }  | 
 | 963 | +        } else {  | 
 | 964 | +            const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];  | 
 | 965 | +            const size_t  db0 = dst->nb[0], db1 = dst->nb[1], db2 = dst->nb[2], db3 = dst->nb[3];  | 
 | 966 | +            const size_t  sb0 = src0->nb[0], sb1 = src0->nb[1], sb2 = src0->nb[2], sb3 = src0->nb[3];  | 
 | 967 | + | 
 | 968 | +            const size_t N  = (size_t) ggml_nelements(dst);  | 
 | 969 | +            const size_t WG = 256;  | 
 | 970 | +            const size_t NG = ((N + WG - 1) / WG) * WG;  | 
 | 971 | + | 
 | 972 | +            const size_t ge0 = (size_t) ne0;  | 
 | 973 | +            const size_t ge1 = ge0 * (size_t) ne1;  | 
 | 974 | +            const size_t ge2 = ge1 * (size_t) ne2;  | 
 | 975 | + | 
 | 976 | +            q->parallel_for(  | 
 | 977 | +                sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),  | 
 | 978 | +                [=](sycl::nd_item<1> it) {  | 
 | 979 | +                    size_t idx = it.get_global_linear_id();  | 
 | 980 | +                    if (idx >= N) return;  | 
 | 981 | + | 
 | 982 | +                    size_t i3 = idx / ge2;  size_t r2 = idx % ge2;  | 
 | 983 | +                    size_t i2 = r2 / ge1;   size_t r1 = r2 % ge1;  | 
 | 984 | +                    size_t i1 = r1 / ge0;   size_t i0 = r1 % ge0;  | 
 | 985 | + | 
 | 986 | +                    const char * s = (const char*)p_src0 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3);  | 
 | 987 | +                    char       * d = (char*)p_dst   + (i0*db0 + i1*db1 + i2*db2 + i3*db3);  | 
 | 988 | + | 
 | 989 | +                    for (size_t b = 0; b < ts; ++b) d[b] = s[b];  | 
 | 990 | +                }  | 
 | 991 | +            );  | 
 | 992 | +        }  | 
 | 993 | +    }  | 
 | 994 | + | 
 | 995 | +    {  | 
 | 996 | +        const int32_t *p = (const int32_t *) dst->op_params;  | 
 | 997 | +        const size_t nb1    = (size_t) p[0];  | 
 | 998 | +        const size_t nb2    = (size_t) p[1];  | 
 | 999 | +        const size_t nb3    = (size_t) p[2];  | 
 | 1000 | +        const size_t offset = (size_t) p[3];  | 
 | 1001 | + | 
 | 1002 | +        const void *p_src1 = src1->data;  | 
 | 1003 | +        void       *p_dst  = dst->data;  | 
 | 1004 | + | 
 | 1005 | +        const size_t sb0 = src1->nb[0], sb1 = src1->nb[1], sb2 = src1->nb[2], sb3 = src1->nb[3];  | 
 | 1006 | +        const size_t db0 = dst->nb[0];   | 
 | 1007 | +        const int64_t ne0 = src1->ne[0], ne1 = src1->ne[1], ne2 = src1->ne[2], ne3 = src1->ne[3];  | 
 | 1008 | + | 
 | 1009 | +        if (ggml_is_contiguous(src1) && db0 == ts) {  | 
 | 1010 | +            const size_t row_bytes = (size_t) ne0 * ts;  | 
 | 1011 | +            const char *s_base = (const char*) p_src1;  | 
 | 1012 | +            char       *d_base = (char*) p_dst + offset;  | 
 | 1013 | + | 
 | 1014 | +            for (int64_t i3 = 0; i3 < ne3; ++i3) {  | 
 | 1015 | +                for (int64_t i2 = 0; i2 < ne2; ++i2) {  | 
 | 1016 | +                    for (int64_t i1 = 0; i1 < ne1; ++i1) {  | 
 | 1017 | +                        const char *s_row = s_base + i1*sb1 + i2*sb2 + i3*sb3;  | 
 | 1018 | +                        char       *d_row = d_base + i1*nb1 + i2*nb2 + i3*nb3;  | 
 | 1019 | + | 
 | 1020 | +                        auto pt_s = sycl::get_pointer_type(s_row, q->get_context());  | 
 | 1021 | +                        auto pt_d = sycl::get_pointer_type(d_row, q->get_context());  | 
 | 1022 | +                        if (pt_s != sycl::usm::alloc::unknown && pt_d != sycl::usm::alloc::unknown) {  | 
 | 1023 | +                            SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(d_row, s_row, row_bytes)));  | 
 | 1024 | +                        } else {  | 
 | 1025 | +                            std::memcpy(d_row, s_row, row_bytes);  | 
 | 1026 | +                        }  | 
 | 1027 | +                    }  | 
 | 1028 | +                }  | 
 | 1029 | +            }  | 
 | 1030 | +        } else {  | 
 | 1031 | +          | 
 | 1032 | +            const size_t N  = (size_t) (ne0 * ne1 * ne2 * ne3);  | 
 | 1033 | +            const size_t WG = 256;  | 
 | 1034 | +            const size_t NG = ((N + WG - 1) / WG) * WG;  | 
 | 1035 | + | 
 | 1036 | +            const size_t ge0 = (size_t) ne0;  | 
 | 1037 | +            const size_t ge1 = ge0 * (size_t) ne1;  | 
 | 1038 | +            const size_t ge2 = ge1 * (size_t) ne2;  | 
 | 1039 | + | 
 | 1040 | +            q->parallel_for(  | 
 | 1041 | +                sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),  | 
 | 1042 | +                [=](sycl::nd_item<1> it) {  | 
 | 1043 | +                    size_t idx = it.get_global_linear_id();  | 
 | 1044 | +                    if (idx >= N) return;  | 
 | 1045 | + | 
 | 1046 | +                    size_t i3 = idx / ge2;  size_t r2 = idx % ge2;  | 
 | 1047 | +                    size_t i2 = r2 / ge1;   size_t r1 = r2 % ge1;  | 
 | 1048 | +                    size_t i1 = r1 / ge0;   size_t i0 = r1 % ge0;  | 
 | 1049 | + | 
 | 1050 | +                    const char * s = (const char*) p_src1 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3);  | 
 | 1051 | +                    char       * d = (char*) p_dst + offset + (i0*db0 + i1*nb1 + i2*nb2 + i3*nb3);  | 
 | 1052 | + | 
 | 1053 | +                    for (size_t b = 0; b < ts; ++b) d[b] = s[b];  | 
 | 1054 | +                }  | 
 | 1055 | +            );  | 
 | 1056 | +        }  | 
 | 1057 | +    }  | 
 | 1058 | +}  | 
929 | 1059 | 
 
  | 
930 | 1060 | static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {  | 
931 | 1061 |     float min_val;  | 
@@ -1124,6 +1254,11 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {  | 
1124 | 1254 |     ggml_sycl_op_pad(ctx, dst);  | 
1125 | 1255 | }  | 
1126 | 1256 | 
 
  | 
// Backend entry point for GGML_OP_SET: copies src0 into dst, then overwrites
// the view of dst described by dst->op_params with src1 (see ggml_sycl_op_set).
void ggml_sycl_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Reads both dst->src[0] and dst->src[1], hence num_src=2 for the debug trace.
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
    ggml_sycl_op_set(ctx, dst);
}
 | 1261 | + | 
1127 | 1262 | void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {  | 
1128 | 1263 |     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);  | 
1129 | 1264 |     ggml_sycl_op_clamp(ctx, dst);  | 
 | 
0 commit comments