@@ -32,16 +32,30 @@ class DnnlGemmWrapper {
         else static_assert(0);
     }
 
-    static void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-        const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q,
-        dnnl_dim_t batches = 1) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches - number of A matrices, equal to number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+            const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+            void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches) {
+
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
         dnnl::memory::dims a_dims = { batches, m, k };
         dnnl::memory::dims b_dims = { batches, k, n };
         dnnl::memory::dims c_dims = { batches, m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::acb : tag::abc);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::acb : tag::abc);
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
         const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
         dnnl::primitive_attr primitive_attr;
@@ -64,6 +78,15 @@ class DnnlGemmWrapper {
 
         matmul_prim.execute(stream, matmul_args);
     }
+
+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1);
+    }
 };
 
 #endif
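
For reference, the sketch below (not part of this commit; gemm_ref, the sample dimensions, and the fill values are made up for illustration) mirrors the stride convention documented on the new gemm(): element (row i, col l) of the p-th A matrix is read from a[p*stride_a + i*nra + l*nca], element (row l, col j) of the p-th B matrix from b[p*stride_b + l*nrb + j*ncb]. It also uses the same stride arguments that row_gemm() forwards, i.e. a single-batch column-major C = A transposed * B.

#include <cstdio>
#include <vector>

// CPU reference with the same stride convention as the new gemm();
// C is written densely as batches x m x n, matching tag::abc.
static void gemm_ref(int m, int n, int k,
                     const float * a, int nra, int nca, int stride_a,
                     const float * b, int nrb, int ncb, int stride_b,
                     float * c, int batches) {
    for (int p = 0; p < batches; ++p)
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int l = 0; l < k; ++l)
                    acc += a[p*stride_a + i*nra + l*nca] * b[p*stride_b + l*nrb + j*ncb];
                c[(p*m + i)*n + j] = acc;
            }
}

int main() {
    const int m = 2, n = 3, k = 4;
    // column-major A (k x m) and B (k x n), the layout row_gemm() documents
    std::vector<float> A(k * m), B(k * n), C(m * n);
    for (int col = 0; col < m; ++col)
        for (int row = 0; row < k; ++row) A[col*k + row] = float(row + 1);
    for (int col = 0; col < n; ++col)
        for (int row = 0; row < k; ++row) B[col*k + row] = float(col + 1);

    // same stride arguments that row_gemm() passes to gemm():
    // A^T is traversed row-major (nra = k, nca = 1), B stays column-major (nrb = 1, ncb = k)
    gemm_ref(m, n, k, A.data(), k, 1, k * m, B.data(), 1, k, n * k, C.data(), 1);

    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) printf("%6.1f ", C[i*n + j]);
        printf("\n");
    }
    return 0;
}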