1313#ifndef  GGML_SYCL_GEMM_HPP
1414#define  GGML_SYCL_GEMM_HPP 
1515
16- #include  < fstream> 
17- #include  < iostream> 
18- 
1916#include  " ggml-sycl.h" 
2017
2118#if  GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ class DnnlGemmWrapper {
3532        else  static_assert (0 );
3633    }
3734
38-     static  inline  void  row_gemm (sycl::queue& q, bool  a_trans,
39-         bool  b_trans, int  m, int  n, int  k,
40-         const  void * a, dt at, const  void * b, dt bt, void * c, dt ct)
41-     {
42-         //  Get the device associated with the queue
43-         sycl::device dev = q.get_device ();
44-         //  Get the context associated with the queue
45-         sycl::context ctx = q.get_context ();
46-         const  dnnl::engine eng = dnnl::sycl_interop::make_engine (dev, ctx);
47-         const  dnnl::stream stream = dnnl::sycl_interop::make_stream (eng, q);
35+     static  inline  void  row_gemm (ggml_backend_sycl_context & ctx, bool  a_trans, bool  b_trans, int  m, int  n, int  k,
36+                                 const  void  * a, dt at, const  void  * b, dt bt, void  * c, dt ct, const  queue_ptr & q) {
37+         auto  stream = ctx.stream_dnnl (q);
38+         auto  eng = ctx.engine_dnnl (q);
4839        dnnl::memory::dims a_dims = { m, k };
4940        dnnl::memory::dims b_dims = { k, n };
5041        dnnl::memory::dims c_dims = { m, n };
5142        const  auto  a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
5243        const  auto  b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
53-         const  auto  c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
54-         auto  a_mem = dnnl::memory (a_in_md, eng, const_cast <void *>(a));
55-         auto  b_mem = dnnl::memory (b_in_md, eng, const_cast <void *>(b));
56-         auto  matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
57-         auto  c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
44+         const  auto  c_md    = dnnl::memory::desc (c_dims, ct, tag::ab);
5845
59-         //  Create the primitive.
60-         auto  matmul_prim = dnnl::matmul (matmul_pd);
61-         //  Primitive arguments.
62-         std::unordered_map<int , dnnl::memory> matmul_args;
63-         matmul_args.insert ({ DNNL_ARG_SRC, a_mem });
64-         matmul_args.insert ({ DNNL_ARG_WEIGHTS, b_mem });
65-         matmul_args.insert ({ DNNL_ARG_DST, c_mem });
46+         dnnl::primitive_attr primitive_attr;
47+         primitive_attr.set_scratchpad_mode (dnnl::scratchpad_mode::user);
6648
67-         matmul_prim.execute (stream, matmul_args);
68-     }
69- 
70- 
71-     static  inline  void  row_gemm (const  dnnl::stream& stream, bool  a_trans,
72-         bool  b_trans, int  m, int  n, int  k,
73-         const  void * a, dt at, const  void * b, dt bt, void * c, dt ct)
74-     {
75-         auto  const  eng = stream.get_engine ();
76-         dnnl::memory::dims a_dims = { m, k };
77-         dnnl::memory::dims b_dims = { k, n };
78-         dnnl::memory::dims c_dims = { m, n };
79-         const  auto  a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
80-         const  auto  b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
81-         const  auto  c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
8249        auto  a_mem = dnnl::memory (a_in_md, eng, const_cast <void *>(a));
8350        auto  b_mem = dnnl::memory (b_in_md, eng, const_cast <void *>(b));
84-         auto  matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
51+         auto  matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md, primitive_attr );
8552        auto  c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
8653
87-         //  Create the primitive.
54+         auto  scratchpad_md = matmul_pd.scratchpad_desc ();
55+         auto  scratchpad_mem = ctx.get_scratchpad_mem (scratchpad_md, eng, q);
8856        auto  matmul_prim = dnnl::matmul (matmul_pd);
89-          //  Primitive arguments. 
57+ 
9058        std::unordered_map<int , dnnl::memory> matmul_args;
9159        matmul_args.insert ({ DNNL_ARG_SRC, a_mem });
9260        matmul_args.insert ({ DNNL_ARG_WEIGHTS, b_mem });
9361        matmul_args.insert ({ DNNL_ARG_DST, c_mem });
62+         matmul_args.insert ({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
9463
9564        matmul_prim.execute (stream, matmul_args);
9665    }
0 commit comments