@@ -225,6 +225,27 @@ void nccl_many2one(
  C10D_NCCL_CHECK(ncclGroupEnd(), "ncclGroupEnd");
}

+void nccl_broadcast(
+    at::Tensor send,
+    at::Tensor recv,
+    int64_t root,
+    int64_t comm_idx) {
+  using namespace c10d;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  auto& comm = *get_nccl_comm(comm_idx);
+
+  C10D_NCCL_CHECK(
+      ncclBroadcast(
+          send.data_ptr(),
+          recv.data_ptr(),
+          send.numel(),
+          to_nccl_data_type(recv.scalar_type()),
+          root,
+          comm,
+          stream.stream()),
+      "ncclBroadcast");
+}
+
void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
  using namespace c10d;
  TORCH_CHECK(src.is_contiguous());
@@ -374,6 +395,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
      "nccl_many2one(Tensor(a!)[] dst, int[] src_ranks, int comm_idx=0) -> ()");

+  m.def(
+      "nccl_broadcast(Tensor send, Tensor(a!) recv, int root, int comm_idx=0) -> ()");
+
  m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");

  m.def(
@@ -406,6 +430,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("nccl_alltoall", nccl_alltoall);
  m.impl("nccl_one2many", nccl_one2many);
  m.impl("nccl_many2one", nccl_many2one);
+  m.impl("nccl_broadcast", nccl_broadcast);
  m.impl("nccl_reducescatter", nccl_reducescatter);
  m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
  m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
@@ -421,6 +446,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
  m.impl("nccl_alltoall", nccl_alltoall);
  m.impl("nccl_one2many", nccl_one2many);
  m.impl("nccl_many2one", nccl_many2one);
+  m.impl("nccl_broadcast", nccl_broadcast);
  m.impl("nccl_reducescatter", nccl_reducescatter);
  m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
  m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
@@ -472,6 +498,14 @@ void nccl_many2one_meta(
  return;
}

+void nccl_broadcast_meta(
+    at::Tensor /*send*/,
+    at::Tensor /*recv*/,
+    int64_t /*root*/,
+    int64_t /*comm_idx*/) {
+  return;
+}
+
void nccl_reducescatter_meta(
    at::Tensor /*dst*/,
    at::Tensor /*src*/,
@@ -512,6 +546,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
  m.impl("nccl_alltoall", nccl_alltoall_meta);
  m.impl("nccl_one2many", nccl_one2many_meta);
  m.impl("nccl_many2one", nccl_many2one_meta);
+  m.impl("nccl_broadcast", nccl_broadcast_meta);
  m.impl("nccl_reducescatter", nccl_reducescatter_meta);
  m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta);
  m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);