3 files changed: +16 -5 lines changed
lines changed Original file line number Diff line number Diff line change @@ -69,9 +69,9 @@ extern "C" {
6969 return func (comm);
7070 }
7171
72- ncclResult_t ncclCommAbort (ncclComm_t comm) {
72+ ncclResult_t ncclCommFinalize (ncclComm_t comm) {
7373 using Signature = ncclResult_t (*)(ncclComm_t comm);
74- static auto func = ctranslate2::load_symbol<Signature>(" ncclCommAbort " );
74+ static auto func = ctranslate2::load_symbol<Signature>(" ncclCommFinalize " );
7575 return func (comm);
7676 }
7777
@@ -196,7 +196,7 @@ namespace ctranslate2 {
196196 for (auto * comm : _nccl_comms) {
197197 // finalizing NCCL
198198 if (*comm) {
199- NCCL_CHECK (ncclCommAbort (*comm));
199+ NCCL_CHECK (ncclCommFinalize (*comm));
200200 NCCL_CHECK (ncclCommDestroy (*comm));
201201 }
202202 }
11#include " ctranslate2/layers/attention.h"
22#include " ctranslate2/ops/split.h"
3+ #include " ctranslate2/utils.h"
4+
35
46#include < algorithm>
57#include < cmath>
@@ -210,11 +212,20 @@ namespace ctranslate2 {
210212 is_decoder,
211213 with_cache ? key_length - 1 : 0 );
212214 }
215+ StorageView* position_bias_per_gpu = position_bias;
216+ StorageView position_bias_tmp (position_bias->dtype (), position_bias->device ());
217+ if (ScopedMPISetter::getCurRank () != 0 ) {
218+ const dim_t num_head_per_gpu = SAFE_DIVIDE (position_bias->dim (0 ), ScopedMPISetter::getNRanks ());
219+ ops::Slide slide_ops (0 , num_head_per_gpu * ScopedMPISetter::getCurRank (),
220+ num_head_per_gpu, true );
221+ slide_ops (*position_bias, position_bias_tmp);
222+ position_bias_per_gpu = &position_bias_tmp;
223+ }
213224
214225 DEVICE_AND_TYPE_DISPATCH (output.device (), output.dtype (),
215- primitives<D>::add_batch_broadcast (position_bias ->data <T>(),
226+ primitives<D>::add_batch_broadcast (position_bias_per_gpu ->data <T>(),
216227 output.data <T>(),
217- position_bias ->size (),
228+ position_bias_per_gpu ->size (),
218229 output.size ()));
219230 }
220231
You can’t perform that action at this time.
0 commit comments