Skip to content

Commit 11bc77b

Browse files
committed
sycl: cleanup oneDNN related code
1 parent 9f7add1 commit 11bc77b

File tree

5 files changed

+76
-64
lines changed

5 files changed

+76
-64
lines changed

docs/backend/SYCL.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB
227227
cmake --build buildWithCublas --config Release
228228
```
229229

230+
**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
231+
232+
```sh
233+
git clone https://github.com/oneapi-src/oneDNN.git
234+
cd oneDNN
235+
cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
236+
cmake --build build-nvidia --config Release
237+
```
238+
230239
- **Adding support to AMD GPUs**
231240

232241
**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
@@ -317,10 +326,10 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
317326
GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
318327

319328
# Option 1: Use FP32 (recommended for better performance in most cases)
320-
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
329+
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
321330

322331
# Option 2: Use FP16
323-
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON
332+
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
324333

325334
# build all binary
326335
cmake --build build --config Release -j -v

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,26 @@ ggml_add_backend_library(ggml-sycl
2121
../../include/ggml-sycl.h
2222
)
2323

24+
find_package(DNNL)
25+
set(GGML_SYCL_DNNL 0)
26+
if(DNNL_FOUND)
27+
get_target_property(CONFIG DNNL::dnnl IMPORTED_CONFIGURATIONS)
28+
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
29+
message(STATUS "Found oneDNN: ${DNNL_LIB}")
30+
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
31+
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
32+
set(GGML_SYCL_DNNL 1)
33+
else()
34+
message(WARNING
35+
"oneDNN must be compiled for the same target as llama.cpp.
36+
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
37+
Disabling oneDNN support.")
38+
endif()
39+
else()
40+
message(STATUS "oneDNN not found, disabling oneDNN support")
41+
endif()
42+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
43+
2444
if (GGML_SYCL_F16)
2545
if (GGML_SYCL_TARGET STREQUAL "AMD")
2646
message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
@@ -46,18 +66,6 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
4666
file(GLOB GGML_SOURCES_SYCL "*.cpp")
4767
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
4868

49-
find_package(DNNL)
50-
message("-- DNNL found:" ${DNNL_FOUND})
51-
52-
if (GGML_SYCL_TARGET STREQUAL "INTEL")
53-
add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
54-
else()
55-
add_compile_definitions(GGML_SYCL_DNNL=0)
56-
endif()
57-
58-
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
59-
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
60-
endif()
6169

6270
if (WIN32)
6371
find_package(IntelSYCL REQUIRED)

ggml/src/ggml-sycl/common.hpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ static size_t g_scratch_offset = 0;
163163
int get_current_device_id();
164164

165165
inline dpct::err0 ggml_sycl_set_device(const int device) try {
166-
167166
int current_device_id;
168167
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
169168

@@ -229,6 +228,14 @@ struct ggml_sycl_pool_alloc {
229228
}
230229
}
231230

231+
T * realloc(size_t size) {
232+
GGML_ASSERT(pool != nullptr);
233+
if (ptr)
234+
pool->free(ptr, actual_size);
235+
ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
236+
return ptr;
237+
}
238+
232239
// size is in number of elements
233240
T * alloc(size_t size) {
234241
GGML_ASSERT(pool != nullptr);
@@ -328,10 +335,29 @@ struct ggml_backend_sycl_context {
328335
dnnl::stream stream_dnnl() {
329336
return stream_dnnl(device, 0);
330337
}
338+
dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
339+
const dnnl::engine & eng, const queue_ptr q) {
340+
ggml_sycl_pool_alloc<uint8_t> * pool;
341+
auto it = scratchpad_map.find(q);
342+
if (it == scratchpad_map.end()) {
343+
scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
344+
pool = scratchpad_map[q].get();
345+
} else {
346+
pool = it->second.get();
347+
}
348+
349+
size_t scratchpad_size = scratchpad_md.get_size();
350+
if (scratchpad_size > pool->actual_size) {
351+
pool->realloc(scratchpad_size);
352+
}
353+
void * mem_ptr = pool->get();
354+
return dnnl::memory(scratchpad_md, eng, mem_ptr);
355+
}
331356
#endif
332357

333358
// pool
334359
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
360+
std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
335361

336362
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
337363

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
#ifndef GGML_SYCL_GEMM_HPP
1414
#define GGML_SYCL_GEMM_HPP
1515

16-
#include <fstream>
17-
#include <iostream>
18-
1916
#include "ggml-sycl.h"
2017

2118
#if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ class DnnlGemmWrapper {
3532
else static_assert(0);
3633
}
3734

38-
static inline void row_gemm(sycl::queue& q, bool a_trans,
39-
bool b_trans, int m, int n, int k,
40-
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
41-
{
42-
// Get the device associated with the queue
43-
sycl::device dev = q.get_device();
44-
// Get the context associated with the queue
45-
sycl::context ctx = q.get_context();
46-
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
47-
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
35+
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36+
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
37+
auto stream = ctx.stream_dnnl(q);
38+
auto eng = ctx.engine_dnnl(q);
4839
dnnl::memory::dims a_dims = { m, k };
4940
dnnl::memory::dims b_dims = { k, n };
5041
dnnl::memory::dims c_dims = { m, n };
5142
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
5243
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
53-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
54-
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
55-
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
56-
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
57-
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
44+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
5845

59-
// Create the primitive.
60-
auto matmul_prim = dnnl::matmul(matmul_pd);
61-
// Primitive arguments.
62-
std::unordered_map<int, dnnl::memory> matmul_args;
63-
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
64-
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
65-
matmul_args.insert({ DNNL_ARG_DST, c_mem });
46+
dnnl::primitive_attr primitive_attr;
47+
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
6648

67-
matmul_prim.execute(stream, matmul_args);
68-
}
69-
70-
71-
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
72-
bool b_trans, int m, int n, int k,
73-
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
74-
{
75-
auto const eng = stream.get_engine();
76-
dnnl::memory::dims a_dims = { m, k };
77-
dnnl::memory::dims b_dims = { k, n };
78-
dnnl::memory::dims c_dims = { m, n };
79-
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
80-
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
81-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
8249
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
8350
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
84-
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
51+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
8552
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
8653

87-
// Create the primitive.
54+
auto scratchpad_md = matmul_pd.scratchpad_desc();
55+
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
8856
auto matmul_prim = dnnl::matmul(matmul_pd);
89-
// Primitive arguments.
57+
9058
std::unordered_map<int, dnnl::memory> matmul_args;
9159
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
9260
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
9361
matmul_args.insert({ DNNL_ARG_DST, c_mem });
62+
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
9463

9564
matmul_prim.execute(stream, matmul_args);
9665
}

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,9 +2629,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
26292629
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
26302630
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
26312631
#else
2632-
auto dnnl_stream = ctx.stream_dnnl(stream);
2633-
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2634-
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2632+
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
2633+
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2634+
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
26352635
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
26362636
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
26372637
#endif
@@ -2670,9 +2670,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
26702670
dst_dd_i, ldc)));
26712671
# endif
26722672
#else
2673-
auto dnnl_stream = ctx.stream_dnnl(stream);
2674-
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2675-
src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2673+
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
2674+
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
2675+
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
26762676
#endif
26772677
}
26782678
GGML_UNUSED(dst);

0 commit comments

Comments
 (0)