Skip to content

Commit b0f14c5

Browse files
committed
Formatting
Signed-off-by: nscipione <[email protected]>
1 parent 6b77639 commit b0f14c5

File tree

3 files changed

+71
-101
lines changed

3 files changed

+71
-101
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,12 @@ struct ggml_backend_sycl_context {
352352

353353
ggml_sycl_pool & host_pool(int device) {
354354
if (host_pools[device] == nullptr) {
355-
host_pools[device] = new_pool_for_host(stream(device,0), device);
355+
host_pools[device] = new_pool_for_host(stream(device, 0), device);
356356
}
357357
return *host_pools[device];
358358
}
359359

360-
ggml_sycl_pool & host_pool() {
361-
return host_pool(device);
362-
}
360+
ggml_sycl_pool & host_pool() { return host_pool(device); }
363361
};
364362

365363
// common device functions

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,12 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282
return device_type.str();
8383
}
8484

85-
template<typename Ts>
86-
struct matrix_info_t
87-
{
85+
template <typename Ts> struct matrix_info_t {
8886
oneapi::mkl::transpose transpose_info[2];
89-
Ts value_info[2];
90-
std::int64_t size_info[3];
91-
std::int64_t ld_info[3];
92-
std::int64_t groupsize_info;
87+
Ts value_info[2];
88+
std::int64_t size_info[3];
89+
std::int64_t ld_info[3];
90+
std::int64_t groupsize_info;
9391
};
9492

9593
namespace dpct
@@ -1737,13 +1735,10 @@ namespace dpct
17371735
};
17381736

17391737
template <class Ta, class Tb, class Tc, class Ts>
1740-
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
1741-
oneapi::mkl::transpose b_trans, int m, int n, int k,
1742-
const void *alpha, const void **a, int lda,
1743-
const void **b, int ldb, const void *beta, void **c,
1744-
int ldc, int batch_size, matrix_info_t<float>* matrix_info)
1745-
{
1746-
1738+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739+
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1740+
int ldb, const void * beta, void ** c, int ldc, int batch_size,
1741+
matrix_info_t<float> * matrix_info) {
17471742
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
17481743
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
17491744

@@ -1763,19 +1758,18 @@ namespace dpct
17631758
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17641759
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
17651760
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1766-
matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info), reinterpret_cast<const Ta **>(a),
1767-
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1768-
reinterpret_cast<Ts*>(matrix_info->value_info+1), reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1769-
&(matrix_info->groupsize_info));
1761+
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1762+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1763+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1764+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17701765
#else
17711766
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17721767
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1773-
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info),
1768+
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
17741769
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1775-
matrix_info->ld_info + 1, reinterpret_cast<Ts*>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
1776-
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1770+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1771+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17771772
#endif
1778-
17791773
}
17801774

17811775
template <class Ta, class Tb, class Tc, class Ts>
@@ -2418,15 +2412,11 @@ namespace dpct
24182412
/// \param [in] ldc Leading dimension of C.
24192413
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24202414
/// \param [in] scaling_type Data type of the scaling factors.
2421-
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
2422-
oneapi::mkl::transpose b_trans, int m, int n, int k,
2423-
const void *alpha, const void *a[],
2424-
library_data_t a_type, int lda, const void *b[],
2425-
library_data_t b_type, int ldb, const void *beta,
2426-
void *c[], library_data_t c_type, int ldc,
2427-
int batch_size, library_data_t scaling_type,
2428-
matrix_info_t<float>* matrix_info)
2429-
{
2415+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2416+
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2417+
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2418+
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2419+
matrix_info_t<float> * matrix_info) {
24302420
std::uint64_t key =
24312421
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
24322422
switch (key)
@@ -2435,48 +2425,41 @@ namespace dpct
24352425
library_data_t::real_float, library_data_t::real_float,
24362426
library_data_t::real_float, library_data_t::real_float):
24372427
{
2438-
detail::gemm_batch_impl<float, float, float, float>(
2439-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2440-
batch_size, matrix_info);
2428+
detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429+
beta, c, ldc, batch_size, matrix_info);
24412430
break;
24422431
}
24432432
case detail::get_type_combination_id(
24442433
library_data_t::real_double, library_data_t::real_double,
24452434
library_data_t::real_double, library_data_t::real_double):
24462435
{
2447-
detail::gemm_batch_impl<double, double, double, double>(
2448-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2449-
batch_size, matrix_info);
2436+
detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437+
beta, c, ldc, batch_size, matrix_info);
24502438
break;
24512439
}
24522440
case detail::get_type_combination_id(
24532441
library_data_t::real_half, library_data_t::real_half,
24542442
library_data_t::real_half, library_data_t::real_half):
24552443
{
2456-
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2457-
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2458-
a, lda, b, ldb, beta, c, ldc,
2459-
batch_size, matrix_info);
2444+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24602446
break;
24612447
}
24622448
#ifdef __INTEL_MKL__
24632449
case detail::get_type_combination_id(
24642450
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
24652451
library_data_t::real_bfloat16, library_data_t::real_float):
24662452
{
2467-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2468-
oneapi::mkl::bfloat16, float>(
2469-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2470-
batch_size, matrix_info);
2453+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2454+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24712455
break;
24722456
}
24732457
case detail::get_type_combination_id(
24742458
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
24752459
library_data_t::real_float, library_data_t::real_float):
24762460
{
2477-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
2478-
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2479-
b, ldb, beta, c, ldc, batch_size, matrix_info);
2461+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2462+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24802463
break;
24812464
}
24822465
#endif
@@ -2488,28 +2471,25 @@ namespace dpct
24882471
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
24892472
float beta_float =
24902473
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2491-
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
2492-
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
2493-
a, lda, b, ldb, &beta_float, c, ldc,
2494-
batch_size, matrix_info);
2474+
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2475+
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476+
matrix_info);
24952477
break;
24962478
}
24972479
case detail::get_type_combination_id(
24982480
library_data_t::real_int8, library_data_t::real_int8,
24992481
library_data_t::real_float, library_data_t::real_float):
25002482
{
25012483
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2502-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2503-
batch_size, matrix_info);
2484+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25042485
break;
25052486
}
25062487
case detail::get_type_combination_id(
25072488
library_data_t::real_half, library_data_t::real_half,
25082489
library_data_t::real_float, library_data_t::real_float):
25092490
{
25102491
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2511-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2512-
batch_size, matrix_info);
2492+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25132493
break;
25142494
}
25152495
case detail::get_type_combination_id(
@@ -2523,8 +2503,7 @@ namespace dpct
25232503
sycl::half alpha_half(alpha_value);
25242504
sycl::half beta_half(beta_value);
25252505
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2526-
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2527-
batch_size, matrix_info);
2506+
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25282507
break;
25292508
}
25302509
default:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,20 +1174,20 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
11741174
};
11751175

11761176
struct ggml_sycl_pool_host : public ggml_sycl_pool {
1177-
11781177
queue_ptr qptr;
1179-
int device;
1178+
int device;
1179+
1180+
inline static int counter{ 0 };
11801181

1181-
inline static int counter{0};
11821182
struct ggml_sycl_buffer {
1183-
void * ptr = nullptr;
1183+
void * ptr = nullptr;
11841184
size_t size = 0;
11851185
};
11861186

11871187
// Set arbitrarily to 64
1188-
static constexpr int MAX_POOL_SIZE{64};
1188+
static constexpr int MAX_POOL_SIZE{ 64 };
11891189
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
1190-
size_t pool_size = 0;
1190+
size_t pool_size = 0;
11911191

11921192
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
11931193

@@ -1205,32 +1205,29 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool {
12051205
}
12061206

12071207
void * alloc(size_t size, size_t * actual_size) override {
1208-
if ( counter == MAX_POOL_SIZE){
1209-
ggml_sycl_buffer b = buffer_pool[0];
1210-
size_t look_ahead_size = (size_t) (1.05 * size);
1211-
void *ptr = b.ptr;
1212-
*actual_size = b.size;
1213-
counter = 1;
1208+
if (counter == MAX_POOL_SIZE) {
1209+
ggml_sycl_buffer b = buffer_pool[0];
1210+
size_t look_ahead_size = (size_t) (1.05 * size);
1211+
void * ptr = b.ptr;
1212+
*actual_size = b.size;
1213+
counter = 1;
12141214
return ptr;
12151215
}
1216-
ggml_sycl_buffer& b = buffer_pool[counter];
1216+
ggml_sycl_buffer & b = buffer_pool[counter];
12171217

12181218
if (b.ptr == nullptr) {
1219-
void * ptr;
1220-
1221-
SYCL_CHECK(
1222-
CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host(
1223-
size, *qptr)));
1224-
if (!ptr) {
1225-
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
1226-
return nullptr;
1227-
}
1228-
pool_size += size;
1229-
*actual_size = size;
1230-
counter = counter + 1;
1231-
return ptr;
1232-
}
1233-
else if (b.ptr != nullptr) {
1219+
void * ptr;
1220+
1221+
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
1222+
if (!ptr) {
1223+
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
1224+
return nullptr;
1225+
}
1226+
pool_size += size;
1227+
*actual_size = size;
1228+
counter = counter + 1;
1229+
return ptr;
1230+
} else if (b.ptr != nullptr) {
12341231
++counter;
12351232
b.size = size;
12361233
return b.ptr;
@@ -1241,9 +1238,9 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool {
12411238
// If the pool is not yet full, store the pointer in the first empty (nullptr) slot.
12421239
// Otherwise do nothing; the pointers will be freed once the pool is deallocated.
12431240
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1244-
ggml_sycl_buffer& b = buffer_pool[i];
1241+
ggml_sycl_buffer & b = buffer_pool[i];
12451242
if (b.ptr == nullptr) {
1246-
b.ptr = ptr;
1243+
b.ptr = ptr;
12471244
b.size = size;
12481245
return;
12491246
}
@@ -3446,7 +3443,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
34463443

34473444
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
34483445
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
3449-
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(),1);
3446+
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
34503447

34513448
sycl::range<3> block_dims(1, ne12, ne13);
34523449
/*
@@ -3475,14 +3472,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
34753472
});
34763473
}
34773474
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3478-
*main_stream, oneapi::mkl::transpose::trans,
3479-
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3480-
(const void **)(ptrs_src.get() + 0 * ne23),
3481-
dpct::library_data_t::real_half, nb01 / nb00,
3482-
(const void **)(ptrs_src.get() + 1 * ne23),
3483-
dpct::library_data_t::real_half, nb11 / nb10, beta,
3484-
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3485-
cu_compute_type, matrix_info.get())));
3475+
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3476+
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
3477+
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
3478+
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
34863479
}
34873480
}
34883481
catch (sycl::exception const &exc) {

0 commit comments

Comments
 (0)