|  | 
|  | 1 | +// | 
|  | 2 | +// MIT license | 
|  | 3 | +// Copyright (C) 2024 Intel Corporation | 
|  | 4 | +// SPDX-License-Identifier: MIT | 
|  | 5 | +// | 
|  | 6 | + | 
|  | 7 | +// | 
|  | 8 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | 9 | +// See https://llvm.org/LICENSE.txt for license information. | 
|  | 10 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | 11 | +// | 
|  | 12 | + | 
|  | 13 | +#ifndef GGML_SYCL_GEMM_HPP | 
|  | 14 | +#define GGML_SYCL_GEMM_HPP | 
|  | 15 | + | 
|  | 16 | +#include <fstream> | 
|  | 17 | +#include <iostream> | 
|  | 18 | + | 
|  | 19 | +#include "ggml-sycl.h" | 
|  | 20 | + | 
|  | 21 | +#if GGML_SYCL_DNNL | 
|  | 22 | + | 
|  | 23 | +#include "dnnl.hpp" | 
|  | 24 | +#include "dnnl_sycl.hpp" | 
|  | 25 | + | 
|  | 26 | +class DnnlGemmWrapper { | 
|  | 27 | +public: | 
|  | 28 | +    using dt = dnnl::memory::data_type; | 
|  | 29 | +    using tag = dnnl::memory::format_tag; | 
|  | 30 | + | 
|  | 31 | +    template<typename T> | 
|  | 32 | +    static constexpr dt to_dt() { | 
|  | 33 | +        if constexpr (std::is_same_v<T, float>) return dt::f32; | 
|  | 34 | +        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16; | 
|  | 35 | +        else static_assert(0); | 
|  | 36 | +    } | 
|  | 37 | + | 
|  | 38 | +    static inline void row_gemm(sycl::queue& q, bool a_trans, | 
|  | 39 | +        bool b_trans, int m, int n, int k, | 
|  | 40 | +        const void* a, dt at, const void* b, dt bt, void* c, dt ct) | 
|  | 41 | +    { | 
|  | 42 | +        // Get the device associated with the queue | 
|  | 43 | +        sycl::device dev = q.get_device(); | 
|  | 44 | +        // Get the context associated with the queue | 
|  | 45 | +        sycl::context ctx = q.get_context(); | 
|  | 46 | +        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); | 
|  | 47 | +        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); | 
|  | 48 | +        dnnl::memory::dims a_dims = { m, k }; | 
|  | 49 | +        dnnl::memory::dims b_dims = { k, n }; | 
|  | 50 | +        dnnl::memory::dims c_dims = { m, n }; | 
|  | 51 | +        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); | 
|  | 52 | +        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); | 
|  | 53 | +        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); | 
|  | 54 | +        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); | 
|  | 55 | +        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); | 
|  | 56 | +        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); | 
|  | 57 | +        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); | 
|  | 58 | + | 
|  | 59 | +        // Create the primitive. | 
|  | 60 | +        auto matmul_prim = dnnl::matmul(matmul_pd); | 
|  | 61 | +        // Primitive arguments. | 
|  | 62 | +        std::unordered_map<int, dnnl::memory> matmul_args; | 
|  | 63 | +        matmul_args.insert({ DNNL_ARG_SRC, a_mem }); | 
|  | 64 | +        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); | 
|  | 65 | +        matmul_args.insert({ DNNL_ARG_DST, c_mem }); | 
|  | 66 | + | 
|  | 67 | +        matmul_prim.execute(stream, matmul_args); | 
|  | 68 | +    } | 
|  | 69 | + | 
|  | 70 | + | 
|  | 71 | +    static inline void row_gemm(const dnnl::stream& stream, bool a_trans, | 
|  | 72 | +        bool b_trans, int m, int n, int k, | 
|  | 73 | +        const void* a, dt at, const void* b, dt bt, void* c, dt ct) | 
|  | 74 | +    { | 
|  | 75 | +        auto const eng = stream.get_engine(); | 
|  | 76 | +        dnnl::memory::dims a_dims = { m, k }; | 
|  | 77 | +        dnnl::memory::dims b_dims = { k, n }; | 
|  | 78 | +        dnnl::memory::dims c_dims = { m, n }; | 
|  | 79 | +        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); | 
|  | 80 | +        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); | 
|  | 81 | +        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); | 
|  | 82 | +        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); | 
|  | 83 | +        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); | 
|  | 84 | +        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); | 
|  | 85 | +        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); | 
|  | 86 | + | 
|  | 87 | +        // Create the primitive. | 
|  | 88 | +        auto matmul_prim = dnnl::matmul(matmul_pd); | 
|  | 89 | +        // Primitive arguments. | 
|  | 90 | +        std::unordered_map<int, dnnl::memory> matmul_args; | 
|  | 91 | +        matmul_args.insert({ DNNL_ARG_SRC, a_mem }); | 
|  | 92 | +        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); | 
|  | 93 | +        matmul_args.insert({ DNNL_ARG_DST, c_mem }); | 
|  | 94 | + | 
|  | 95 | +        matmul_prim.execute(stream, matmul_args); | 
|  | 96 | +    } | 
|  | 97 | +}; | 
|  | 98 | + | 
|  | 99 | +#endif | 
|  | 100 | + | 
|  | 101 | +#endif // GGML_SYCL_GEMM_HPP | 
0 commit comments