1313#ifndef GGML_SYCL_GEMM_HPP
1414#define GGML_SYCL_GEMM_HPP
1515
16- #include < fstream>
17- #include < iostream>
18-
1916#include " ggml-sycl.h"
2017
2118#if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ class DnnlGemmWrapper {
3532 else static_assert (0 );
3633 }
3734
38- static inline void row_gemm (sycl::queue& q, bool a_trans,
39- bool b_trans, int m, int n, int k,
40- const void * a, dt at, const void * b, dt bt, void * c, dt ct)
41- {
42- // Get the device associated with the queue
43- sycl::device dev = q.get_device ();
44- // Get the context associated with the queue
45- sycl::context ctx = q.get_context ();
46- const dnnl::engine eng = dnnl::sycl_interop::make_engine (dev, ctx);
47- const dnnl::stream stream = dnnl::sycl_interop::make_stream (eng, q);
35+ static inline void row_gemm (ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
37+ auto stream = ctx.stream_dnnl (q);
38+ auto eng = ctx.engine_dnnl (q);
4839 dnnl::memory::dims a_dims = { m, k };
4940 dnnl::memory::dims b_dims = { k, n };
5041 dnnl::memory::dims c_dims = { m, n };
5142 const auto a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
5243 const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
53- const auto c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
54- auto a_mem = dnnl::memory (a_in_md, eng, const_cast <void *>(a));
55- auto b_mem = dnnl::memory (b_in_md, eng, const_cast <void *>(b));
56- auto matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
57- auto c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
44+ const auto c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
5845
59- // Create the primitive.
60- auto matmul_prim = dnnl::matmul (matmul_pd);
61- // Primitive arguments.
62- std::unordered_map<int , dnnl::memory> matmul_args;
63- matmul_args.insert ({ DNNL_ARG_SRC, a_mem });
64- matmul_args.insert ({ DNNL_ARG_WEIGHTS, b_mem });
65- matmul_args.insert ({ DNNL_ARG_DST, c_mem });
46+ dnnl::primitive_attr primitive_attr;
47+ primitive_attr.set_scratchpad_mode (dnnl::scratchpad_mode::user);
6648
67- matmul_prim.execute (stream, matmul_args);
68- }
69-
70-
71- static inline void row_gemm (const dnnl::stream& stream, bool a_trans,
72- bool b_trans, int m, int n, int k,
73- const void * a, dt at, const void * b, dt bt, void * c, dt ct)
74- {
75- auto const eng = stream.get_engine ();
76- dnnl::memory::dims a_dims = { m, k };
77- dnnl::memory::dims b_dims = { k, n };
78- dnnl::memory::dims c_dims = { m, n };
79- const auto a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
80- const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
81- const auto c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
8249 auto a_mem = dnnl::memory (a_in_md, eng, const_cast <void *>(a));
8350 auto b_mem = dnnl::memory (b_in_md, eng, const_cast <void *>(b));
84- auto matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
51+ auto matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md, primitive_attr );
8552 auto c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
8653
87- // Create the primitive.
54+ auto scratchpad_md = matmul_pd.scratchpad_desc ();
55+ auto scratchpad_mem = ctx.get_scratchpad_mem (scratchpad_md, eng, q);
8856 auto matmul_prim = dnnl::matmul (matmul_pd);
89- // Primitive arguments.
57+
9058 std::unordered_map<int , dnnl::memory> matmul_args;
9159 matmul_args.insert ({ DNNL_ARG_SRC, a_mem });
9260 matmul_args.insert ({ DNNL_ARG_WEIGHTS, b_mem });
9361 matmul_args.insert ({ DNNL_ARG_DST, c_mem });
62+ matmul_args.insert ({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
9463
9564 matmul_prim.execute (stream, matmul_args);
9665 }
0 commit comments