Commit cb8507b

Author: Chen Xi (committed)
add tensor parallelism support to SYCL

Signed-off-by: Chen Xi <[email protected]>
1 parent 7691654 commit cb8507b

File tree: 5 files changed, +269 −41 lines changed

ggml/include/ggml.h
Lines changed: 23 additions & 0 deletions

@@ -581,6 +581,27 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    // ggml object
+    struct ggml_object {
+        size_t offs;
+        size_t size;
+
+        struct ggml_object * next;
+
+        enum ggml_object_type type;
+
+        char padding[4];
+    };
+
+    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+    enum tensor_parallel_mode {
+        TENSOR_NO_CHANGE,
+        TENSOR_SPLIT_BY_ROW,
+        TENSOR_SPLIT_BY_COLUMN,
+        TENSOR_KEEPED_ON_MASTER,
+    };
+
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
@@ -616,6 +637,8 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
+        enum tensor_parallel_mode split_mode = tensor_parallel_mode::TENSOR_NO_CHANGE;
+
         // char padding[4];
     };
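The new `tensor_parallel_mode` enum and the `split_mode` field on `ggml_tensor` let a loader tag each weight with how it should be partitioned across ranks. Below is a hypothetical sketch of such tagging for a feed-forward block; the `tag_for_tp` helper and the tensor names are illustrative, not part of this commit. Splitting the up-projection by row and the down-projection by column keeps the intermediate activations local to each rank, so only one allreduce (after the column-split matmul) is needed.

    // Hypothetical sketch, not from this commit: tag FFN weights for
    // Megatron-style tensor parallelism using the new split_mode field.
    #include "ggml.h"

    static void tag_for_tp(struct ggml_tensor * ffn_up,
                           struct ggml_tensor * ffn_down,
                           struct ggml_tensor * ffn_down_bias) {
        ffn_up->split_mode        = TENSOR_SPLIT_BY_ROW;      // each rank owns a slice of output rows
        ffn_down->split_mode      = TENSOR_SPLIT_BY_COLUMN;   // partial sums; allreduce after matmul
        ffn_down_bias->split_mode = TENSOR_KEEPED_ON_MASTER;  // added once on rank 0, zeroed elsewhere
    }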

ggml/src/ggml-sycl.cpp
Lines changed: 105 additions & 8 deletions

@@ -1239,6 +1239,15 @@ static void relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
+static void allreduce_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    auto dev = ccl::create_device(stream->get_device());
+    auto ctx = ccl::create_context(stream->get_context());
+    auto comm = dpct::dev_mgr::instance().create_ccl_communicator(dev, ctx);
+    auto ccl_stream = ccl::create_stream(*stream);
+    ccl::allreduce(x, dst, k, ccl::reduction::sum, comm, ccl_stream).wait();
+}
+
 static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
                                  queue_ptr stream) {
     const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
@@ -1736,6 +1745,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
             global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+int ggml_backend_sycl_rank() {
+    // use the ccl rank as the main gpu
+    return dpct::dev_mgr::instance().get_ccl_rank();
+}
+
+int ggml_backend_sycl_world_size() {
+    // use the ccl world size as the device count
+    return dpct::dev_mgr::instance().get_ccl_world_size();
+}
+
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -2270,6 +2289,21 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }
 
+inline void ggml_sycl_op_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                                   ggml_tensor *dst, const float *src0_dd,
+                                   const float *src1_dd, float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    allreduce_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                      const ggml_tensor *src1, ggml_tensor *dst,
                                      const float *src0_dd, const float *src1_dd,
@@ -3179,6 +3213,13 @@ static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_allreduce);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
@@ -3530,6 +3571,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     } else {
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
+    if (src0->split_mode == tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN) {
+        ggml_sycl_allreduce(ctx, dst, src1, dst);
+    }
 }
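When a weight is split along its inner (column) dimension, each rank multiplies only its slice of the reduction dimension, so every rank ends up with a partial result; the element-wise sum across ranks equals the full product, which is why `ggml_sycl_mul_mat` now finishes with an allreduce. A standalone sketch of the arithmetic (plain C++, no SYCL or oneCCL needed), with the rank loop standing in for the collective:

    // Why TENSOR_SPLIT_BY_COLUMN needs ccl::allreduce(sum): partial matvec
    // products over column slices sum to the full matvec product.
    #include <cstdio>

    int main() {
        const int rows = 2, cols = 4, world_size = 2;
        const float W[rows][cols] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
        const float x[cols]       = {1, 1, 1, 1};

        float acc[rows] = {0, 0};                // the allreduce accumulator
        for (int rank = 0; rank < world_size; ++rank) {
            const int nc = cols / world_size;    // columns owned by this rank
            for (int r = 0; r < rows; ++r) {
                float partial = 0.0f;            // this rank's share of y[r]
                for (int c = 0; c < nc; ++c) {
                    partial += W[r][rank * nc + c] * x[rank * nc + c];
                }
                acc[r] += partial;               // done across ranks by allreduce
            }
        }
        printf("y = [%g, %g]  (expect [10, 26])\n", acc[0], acc[1]);
        return 0;
    }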

@@ -4193,6 +4237,41 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
+static bool split_tensor(const struct ggml_tensor * src, void * dst, const void * data, int split_mode) {
+    int rank = ggml_backend_sycl_rank();
+    int world_size = ggml_backend_sycl_world_size();
+    auto type_traits = ggml_internal_get_type_traits(src->type);
+    size_t element_size = type_traits.type_size / type_traits.blck_size;
+    const int64_t dst_size = ggml_nelements(src) * element_size / world_size;
+    switch (split_mode) {
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN: {
+            const int64_t nr = ggml_nrows(src);
+            const int64_t nc = src->ne[0];
+            const int64_t ndc = nc / world_size;
+            for (int64_t i = 0; i < nr; ++i) {
+                memcpy(((char*)dst) + i * ndc * element_size,
+                       ((const char*)data) + i * nc * element_size + ndc * rank * element_size, ndc * element_size);
+            }
+        } break;
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_ROW: {
+            memcpy(((char*)dst), ((const char*)data) + dst_size * rank, dst_size);
+        } break;
+        case tensor_parallel_mode::TENSOR_KEEPED_ON_MASTER: {
+            if (rank == 0) {
+                memcpy(((char*)dst), ((const char*)data), dst_size);
+            } else {
+                memset(((char*)dst), 0, dst_size);
+            }
+        } break;
+        default: {
+            return false;
+        } break;
+    }
+    return true;
+}
+
 static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 ggml_tensor *tensor,
                                                 const void *data, size_t offset,
@@ -4205,7 +4284,14 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
     SYCL_CHECK(
         CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
     char* host_buf = (char*)malloc(size);
-    memcpy(host_buf, data, size);
+
+    if (tensor->split_mode == tensor_parallel_mode::TENSOR_NO_CHANGE) {
+        memcpy(host_buf, data, size);
+    } else {
+        if (!split_tensor(tensor, host_buf, data, tensor->split_mode)) {
+            std::cerr << "split tensor failed!" << std::endl;
+        }
+    }
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
                             .wait()));
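The `TENSOR_SPLIT_BY_COLUMN` branch of `split_tensor` copies a strided sub-matrix out of the row-major source buffer, one row slice at a time. A small host-only sketch of that copy, assuming a 2x4 float matrix split across two ranks:

    // Sketch of the column-slice memcpy loop in split_tensor:
    // rank 1 of 2 receives the right 2x2 half of a 2x4 row-major matrix.
    #include <cstdio>
    #include <cstring>

    int main() {
        const float data[2 * 4] = {1, 2, 3, 4,  5, 6, 7, 8};
        const int nr = 2, nc = 4, world_size = 2, rank = 1;
        const int ndc = nc / world_size;     // columns owned per rank
        float dst[2 * 2];
        for (int i = 0; i < nr; ++i) {       // one contiguous slice per row
            memcpy(dst + i * ndc, data + i * nc + ndc * rank, ndc * sizeof(float));
        }
        printf("rank %d slice: %g %g %g %g\n", rank, dst[0], dst[1], dst[2], dst[3]); // 3 4 7 8
        return 0;
    }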
@@ -4419,14 +4505,25 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static bool ggml_backend_sycl_buffer_type_initialized = false;
 
     if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            auto & device_i = dpct::dev_mgr::instance().get_device(i);
-            queue_ptr stream = &(device_i.default_queue());
-            ggml_backend_sycl_buffer_types[i] = {
+        if (dpct::dev_mgr::instance().get_ccl_world_size() > 1) {
+            auto rank = dpct::dev_mgr::instance().get_ccl_rank();
+            auto & device_tp = dpct::dev_mgr::instance().get_device(rank);
+            queue_ptr stream = &(device_tp.default_queue());
+            // TODO(xi): buffer_types always use 0 to avoid changes on public code
+            ggml_backend_sycl_buffer_types[0] = {
                 /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
-                /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
-            };
-        }
+                /* .context = */ new ggml_backend_sycl_buffer_type_context{rank, GGML_SYCL_NAME + std::to_string(rank), stream},
+            };
+        } else {
+            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+                auto & device_i = dpct::dev_mgr::instance().get_device(i);
+                queue_ptr stream = &(device_i.default_queue());
+                ggml_backend_sycl_buffer_types[i] = {
+                    /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
+                    /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+                };
+            }
+        }
         ggml_backend_sycl_buffer_type_initialized = true;
     }
     return &ggml_backend_sycl_buffer_types[device];

ggml/src/ggml-sycl/dpct/helper.hpp
Lines changed: 34 additions & 1 deletion

@@ -15,6 +15,7 @@
 
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
+#include <oneapi/ccl.hpp>
 #include <oneapi/mkl.hpp>
 #include <map>
 
@@ -479,6 +480,8 @@ namespace dpct
         int _max_nd_range_size_i[3];
         uint32_t _device_id;
         std::array<unsigned char, 16> _uuid;
+        uint32_t _rank;
+        uint32_t _world_size;
     };
 
     static int get_major_version(const sycl::device &dev)
@@ -870,7 +873,12 @@ namespace dpct
            }
            return -1;
        }
-
+       inline int get_ccl_rank() { return _rank; }
+       inline int get_ccl_world_size() { return _world_size; }
+       inline ccl::communicator create_ccl_communicator(ccl::device dev, ccl::context ctx) {
+           return ccl::create_communicator(_world_size, _rank, dev, ctx, _kvs);
+       }
        inline std::string get_preferred_gpu_platform_name() {
            std::string result;
@@ -993,6 +1001,26 @@ namespace dpct
        static bool compare_backend(std::string &backend1, std::string &backend2) {
            return convert_backend_index(backend1) < convert_backend_index(backend2);
        }
+
+       void init_ccl() {
+           ccl::init();
+           MPI_Init(NULL, NULL);
+           MPI_Comm_size(MPI_COMM_WORLD, &_world_size);
+           MPI_Comm_rank(MPI_COMM_WORLD, &_rank);
+           atexit(mpi_finalize);
+           ccl::kvs::address_type main_addr;
+           if (_rank == 0) {
+               _kvs = ccl::create_main_kvs();
+               main_addr = _kvs->get_address();
+               MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+           } else {
+               MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+               _kvs = ccl::create_kvs(main_addr);
+           }
+       }
+
        dev_mgr()
        {
            sycl::device default_device =
@@ -1050,6 +1078,7 @@ namespace dpct
                    _cpu_device = _devs.size() - 1;
                }
            }
+           init_ccl();
        }
        void check_id(unsigned int id) const
        {
@@ -1066,6 +1095,10 @@ namespace dpct
        /// thread-id to device-id map.
        std::map<unsigned int, unsigned int> _thread2dev_map;
        int _cpu_device = -1;
+       // for tensor parallelism
+       int _rank = 0;
+       int _world_size = 1;
+       ccl::shared_ptr_class<ccl::kvs> _kvs;
    };
 
    static inline sycl::queue &get_default_queue()
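The key-value-store handshake in `init_ccl()` is the standard oneCCL + MPI bootstrap: rank 0 creates the main KVS and broadcasts its fixed-size address so the other ranks can attach before any communicator is created. For reference, a self-contained sketch of the same pattern (assuming the process is started under an MPI launcher such as `mpirun -np N`):

    // Minimal oneCCL bootstrap over MPI, following the pattern used by
    // init_ccl() above; error handling omitted for brevity.
    #include <mpi.h>
    #include <oneapi/ccl.hpp>

    int main(int argc, char ** argv) {
        ccl::init();
        MPI_Init(&argc, &argv);
        int rank = 0, world_size = 1;
        MPI_Comm_size(MPI_COMM_WORLD, &world_size);
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);

        ccl::shared_ptr_class<ccl::kvs> kvs;
        ccl::kvs::address_type main_addr;    // fixed-size address buffer
        if (rank == 0) {
            kvs = ccl::create_main_kvs();
            main_addr = kvs->get_address();
            MPI_Bcast(main_addr.data(), (int)main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
        } else {
            MPI_Bcast(main_addr.data(), (int)main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
            kvs = ccl::create_kvs(main_addr);
        }
        // ... create per-device communicators and run collectives ...
        MPI_Finalize();
        return 0;
    }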

include/llama.h
Lines changed: 1 addition & 0 deletions

@@ -204,6 +204,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
         LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
+        LLAMA_SPLIT_MODE_TENSOR = 3, // split tensors across GPUs
     };
 
     // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
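A hedged sketch of how a caller could request the new mode through the existing model parameters; this assumes the loader maps `LLAMA_SPLIT_MODE_TENSOR` onto the per-tensor `split_mode` tags, which is wiring not shown in this diff. `llama_model_default_params()`, `llama_load_model_from_file()`, and `llama_free_model()` are the existing llama.cpp API; "model.gguf" is a placeholder path.

    // Sketch only: request tensor-parallel splitting at model load time.
    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        mparams.split_mode = LLAMA_SPLIT_MODE_TENSOR; // split tensors across GPUs
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1; // load failed
        }
        llama_free_model(model);
        return 0;
    }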
