@@ -1239,6 +1239,15 @@ static void relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
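+// Sums the k floats in x across all ranks with oneCCL and writes the result to dst (blocking).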
+static void allreduce_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    auto dev = ccl::create_device(stream->get_device());
+    auto ctx = ccl::create_context(stream->get_context());
+    auto comm = dpct::dev_mgr::instance().create_ccl_communicator(dev, ctx);
+    auto ccl_stream = ccl::create_stream(*stream);
+    ccl::allreduce(x, dst, k, ccl::reduction::sum, comm, ccl_stream).wait();
+}
+
 static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
                                  queue_ptr stream) {
     const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
@@ -1736,6 +1745,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
             global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+int ggml_backend_sycl_rank() {
+    // use the ccl rank as the main gpu
+    return dpct::dev_mgr::instance().get_ccl_rank();
+}
+
+int ggml_backend_sycl_world_size() {
+    // total number of ccl ranks
+    return dpct::dev_mgr::instance().get_ccl_world_size();
+}
+
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -2270,6 +2289,21 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }
 
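+// Op wrapper: all-reduces src0's F32 data into dst across ranks; src1 is unused.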
+inline void ggml_sycl_op_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                                   ggml_tensor *dst, const float *src0_dd,
+                                   const float *src1_dd, float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    allreduce_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                      const ggml_tensor *src1, ggml_tensor *dst,
                                      const float *src0_dd, const float *src1_dd,
@@ -3179,6 +3213,13 @@ static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
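+// Dispatches the allreduce op through ggml_sycl_op_flatten, like the other element-wise ops.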
+static void ggml_sycl_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_allreduce);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+
 static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
@@ -3530,6 +3571,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     } else {
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
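+    // A column-split weight leaves each rank with a partial result, so sum dst across ranks.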
+    if (src0->split_mode == tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN) {
+        ggml_sycl_allreduce(ctx, dst, src1, dst);
+    }
 }
 
 
@@ -4193,6 +4237,41 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
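+// Copies this rank's shard of data into dst according to the tensor's split mode; returns false on an unknown mode.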
+static bool split_tensor(const struct ggml_tensor * src, void * dst, const void * data, int split_mode) {
+    int rank = ggml_backend_sycl_rank();
+    int world_size = ggml_backend_sycl_world_size();
+    auto type_traits = ggml_internal_get_type_traits(src->type);
+    size_t element_size = type_traits.type_size / type_traits.blck_size;
+    const int64_t dst_size = ggml_nelements(src) * element_size / world_size;
+    switch (split_mode) {
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN: {
+            // each rank keeps a contiguous slice of every row
+            const int64_t nr = ggml_nrows(src);
+            const int64_t nc = src->ne[0];
+            const int64_t ndc = nc / world_size;
+            for (int64_t i = 0; i < nr; ++i) {
+                memcpy(((char *)dst) + i * ndc * element_size,
+                       ((const char *)data) + i * nc * element_size + ndc * rank * element_size, ndc * element_size);
+            }
+        } break;
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_ROW: {
+            // rows are contiguous, so each rank takes one block of the buffer
+            memcpy(((char *)dst), ((const char *)data) + dst_size * rank, dst_size);
+        } break;
+        case tensor_parallel_mode::TENSOR_KEEPED_ON_MASTER: {
+            // only rank 0 keeps the data; the other ranks zero their copy
+            if (rank == 0) {
+                memcpy(((char *)dst), ((const char *)data), dst_size);
+            } else {
+                memset(((char *)dst), 0, dst_size);
+            }
+        } break;
+        default: {
+            return false;
+        } break;
+    }
+    return true;
+}
+
+
 static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 ggml_tensor *tensor,
                                                 const void *data, size_t offset,
@@ -4205,7 +4284,14 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
     SYCL_CHECK(
         CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
     char * host_buf = (char *)malloc(size);
-    memcpy(host_buf, data, size);
+
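+    // Tensors marked for tensor parallelism are sharded on the host before the device copy.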
+    if (tensor->split_mode == tensor_parallel_mode::TENSOR_NO_CHANGE) {
+        memcpy(host_buf, data, size);
+    } else {
+        if (!split_tensor(tensor, host_buf, data, tensor->split_mode)) {
+            std::cerr << "split tensor failed!" << std::endl;
+        }
+    }
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
                             .wait()));
@@ -4419,14 +4505,25 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static bool ggml_backend_sycl_buffer_type_initialized = false;
 
     if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            auto & device_i = dpct::dev_mgr::instance().get_device(i);
-            queue_ptr stream = &(device_i.default_queue());
-            ggml_backend_sycl_buffer_types[i] = {
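+        // With more than one rank, bind this process to the device for its ccl rank and expose it as buffer type 0.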
+        if (dpct::dev_mgr::instance().world_size() > 1) {
+            auto rank = dpct::dev_mgr::instance().get_rank();
+            auto & device_tp = dpct::dev_mgr::instance().get_device(rank);
+            queue_ptr stream = &(device_tp.default_queue());
+            // TODO(xi): buffer_types always use index 0 to avoid changes to the public code
+            ggml_backend_sycl_buffer_types[0] = {
                 /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
-                /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
-            };
-        }
+                /* .context = */ new ggml_backend_sycl_buffer_type_context{rank, GGML_SYCL_NAME + std::to_string(rank), stream},
+            };
+        } else {
+            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+                auto & device_i = dpct::dev_mgr::instance().get_device(i);
+                queue_ptr stream = &(device_i.default_queue());
+                ggml_backend_sycl_buffer_types[i] = {
+                    /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
+                    /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+                };
+            }
+        }
         ggml_backend_sycl_buffer_type_initialized = true;
     }
     return &ggml_backend_sycl_buffer_types[device];