Commit d1340d8

jasonjk-park authored and facebook-github-bot committed
broadcast (pytorch#4874)
Summary:
Pull Request resolved: pytorch#4874
X-link: facebookresearch/FBGEMM#1896

Add a pass-through for broadcast.

Reviewed By: q10

Differential Revision: D82354222

fbshipit-source-id: 1b4c8fec5d128ba43fdc36a01057d4bca29c7524
1 parent 36c506b commit d1340d8

File tree

1 file changed, +35 -0 lines changed
  • fbgemm_gpu/experimental/gen_ai/src/comm


fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

Lines changed: 35 additions & 0 deletions
@@ -225,6 +225,27 @@ void nccl_many2one(
   C10D_NCCL_CHECK(ncclGroupEnd(), "ncclGroupEnd");
 }
 
+void nccl_broadcast(
+    at::Tensor send,
+    at::Tensor recv,
+    int64_t root,
+    int64_t comm_idx) {
+  using namespace c10d;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  auto& comm = *get_nccl_comm(comm_idx);
+
+  C10D_NCCL_CHECK(
+      ncclBroadcast(
+          send.data_ptr(),
+          recv.data_ptr(),
+          send.numel(),
+          to_nccl_data_type(recv.scalar_type()),
+          root,
+          comm,
+          stream.stream()),
+      "ncclBroadcast");
+}
+
 void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
   using namespace c10d;
   TORCH_CHECK(src.is_contiguous());
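For context, a minimal sketch of a call site for the new collective. The helper name broadcast_example, the tensor shape, and the assumption that an NCCL communicator is already registered at comm_idx 0 (through the file's existing get_nccl_comm machinery) are illustrative, not part of this diff:

#include <ATen/ATen.h>

// Hypothetical call site; every participating rank must call this
// collectively with matching numel and dtype. send is read only on the
// root rank, while recv is overwritten on every rank.
void broadcast_example(int64_t root) {
  auto send = at::randn({1024}, at::device(at::kCUDA).dtype(at::kFloat));
  auto recv = at::empty_like(send);
  nccl_broadcast(send, recv, root, /*comm_idx=*/0);
}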
@@ -374,6 +395,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "nccl_many2one(Tensor(a!)[] dst, int[] src_ranks, int comm_idx=0) -> ()");
 
+  m.def(
+      "nccl_broadcast(Tensor send, Tensor(a!) recv, int root, int comm_idx=0) -> ()");
+
   m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");
 
   m.def(
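The Tensor(a!) annotation in the schema above tells the dispatcher that recv is mutated in place while send is read-only. A minimal sketch of invoking the registered op through the PyTorch dispatcher rather than calling the C++ function directly (the wrapper name broadcast_via_dispatcher is illustrative):

#include <ATen/core/dispatch/Dispatcher.h>

// Looks up the schema registered above and dispatches on the tensors'
// device: CUDA inputs reach the NCCL implementation, meta tensors the
// shape-only stub registered later in this file.
void broadcast_via_dispatcher(at::Tensor send, at::Tensor recv, int64_t root) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("fbgemm::nccl_broadcast", "")
      .typed<void(at::Tensor, at::Tensor, int64_t, int64_t)>();
  op.call(send, recv, root, /*comm_idx=*/0);
}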
@@ -406,6 +430,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("nccl_alltoall", nccl_alltoall);
   m.impl("nccl_one2many", nccl_one2many);
   m.impl("nccl_many2one", nccl_many2one);
+  m.impl("nccl_broadcast", nccl_broadcast);
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
@@ -421,6 +446,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("nccl_alltoall", nccl_alltoall);
   m.impl("nccl_one2many", nccl_one2many);
   m.impl("nccl_many2one", nccl_many2one);
+  m.impl("nccl_broadcast", nccl_broadcast);
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
@@ -472,6 +498,14 @@ void nccl_many2one_meta(
   return;
 }
 
+void nccl_broadcast_meta(
+    at::Tensor /*send*/,
+    at::Tensor /*recv*/,
+    int64_t /*root*/,
+    int64_t /*comm_idx*/) {
+  return;
+}
+
 void nccl_reducescatter_meta(
     at::Tensor /* dst */,
     at::Tensor /* src */,
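The meta stub above is deliberately a no-op, matching the other *_meta stubs in this file: meta tensors carry only shapes and dtypes, and because recv is pre-allocated by the caller, there is no output shape for tracing (e.g. under torch.compile) to compute.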
@@ -512,6 +546,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("nccl_alltoall", nccl_alltoall_meta);
   m.impl("nccl_one2many", nccl_one2many_meta);
   m.impl("nccl_many2one", nccl_many2one_meta);
+  m.impl("nccl_broadcast", nccl_broadcast_meta);
   m.impl("nccl_reducescatter", nccl_reducescatter_meta);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);
