
Commit 5eb5d5a

fix position bias in tensor parallel (#1714)
* fix position bias in tensor parallel
* add symbol ncclCommFinalize
1 parent 3b248f1 commit 5eb5d5a

File tree

3 files changed, 16 insertions(+), 5 deletions(-)


src/cuda/nccl_stub.cc

Lines changed: 2 additions & 2 deletions
@@ -69,9 +69,9 @@ extern "C" {
     return func(comm);
   }
 
-  ncclResult_t ncclCommAbort(ncclComm_t comm) {
+  ncclResult_t ncclCommFinalize(ncclComm_t comm) {
     using Signature = ncclResult_t(*)(ncclComm_t comm);
-    static auto func = ctranslate2::load_symbol<Signature>("ncclCommAbort");
+    static auto func = ctranslate2::load_symbol<Signature>("ncclCommFinalize");
     return func(comm);
   }
 
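The stub resolves each NCCL symbol lazily at runtime instead of linking against the library at build time; this change adds ncclCommFinalize to that table. A minimal sketch of the underlying pattern, assuming a POSIX dlopen/dlsym loader (the library path and error handling here are illustrative assumptions, not CTranslate2's actual implementation):

// Sketch of lazy symbol resolution via dlsym (POSIX; link with -ldl).
#include <dlfcn.h>
#include <stdexcept>
#include <string>

template <typename Signature>
Signature load_symbol(const char* name) {
  // Open (or reuse) the NCCL shared library once per process.
  static void* handle = dlopen("libnccl.so.2", RTLD_LAZY | RTLD_GLOBAL);
  if (!handle)
    throw std::runtime_error("Cannot open libnccl.so.2");
  void* symbol = dlsym(handle, name);
  if (!symbol)
    throw std::runtime_error("Cannot load symbol " + std::string(name));
  return reinterpret_cast<Signature>(symbol);
}

// Usage, mirroring the stub above:
//   static auto func =
//       load_symbol<ncclResult_t(*)(ncclComm_t)>("ncclCommFinalize");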

src/devices.cc

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ namespace ctranslate2 {
     for (auto* comm : _nccl_comms) {
       //finalizing NCCL
       if (*comm) {
-        NCCL_CHECK(ncclCommAbort(*comm));
+        NCCL_CHECK(ncclCommFinalize(*comm));
         NCCL_CHECK(ncclCommDestroy(*comm));
       }
     }
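ncclCommAbort terminates a communicator immediately and is intended for error recovery, while ncclCommFinalize (available since NCCL 2.14) waits for outstanding operations to complete before ncclCommDestroy releases the resources. A standalone sketch of the graceful shutdown order this change adopts (the NCCL_CHECK macro below is a simplified stand-in for the project's own):

// Graceful NCCL communicator teardown: finalize, then destroy.
#include <nccl.h>
#include <cstdio>
#include <cstdlib>

#define NCCL_CHECK(call)                                       \
  do {                                                         \
    ncclResult_t status = (call);                              \
    if (status != ncclSuccess) {                               \
      std::fprintf(stderr, "NCCL error: %s\n",                 \
                   ncclGetErrorString(status));                \
      std::exit(EXIT_FAILURE);                                 \
    }                                                          \
  } while (0)

void shutdown_comm(ncclComm_t comm) {
  if (comm) {
    NCCL_CHECK(ncclCommFinalize(comm));  // flush pending NCCL work
    NCCL_CHECK(ncclCommDestroy(comm));   // then free resources
  }
}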

src/layers/attention.cc

Lines changed: 13 additions & 2 deletions
@@ -1,5 +1,7 @@
 #include "ctranslate2/layers/attention.h"
 #include "ctranslate2/ops/split.h"
+#include "ctranslate2/utils.h"
+
 
 #include <algorithm>
 #include <cmath>
@@ -210,11 +212,20 @@ namespace ctranslate2 {
                                is_decoder,
                                with_cache ? key_length - 1 : 0);
       }
+      StorageView* position_bias_per_gpu = position_bias;
+      StorageView position_bias_tmp(position_bias->dtype(), position_bias->device());
+      if (ScopedMPISetter::getCurRank() != 0) {
+        const dim_t num_head_per_gpu = SAFE_DIVIDE(position_bias->dim(0), ScopedMPISetter::getNRanks());
+        ops::Slide slide_ops(0, num_head_per_gpu * ScopedMPISetter::getCurRank(),
+                             num_head_per_gpu, true);
+        slide_ops(*position_bias, position_bias_tmp);
+        position_bias_per_gpu = &position_bias_tmp;
+      }
 
       DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
-                               primitives<D>::add_batch_broadcast(position_bias->data<T>(),
+                               primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
                                                                   output.data<T>(),
-                                                                  position_bias->size(),
+                                                                  position_bias_per_gpu->size(),
                                                                   output.size()));
     }
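Under tensor parallelism the attention heads are partitioned across ranks, but the relative position bias is computed over all heads, so each rank past rank 0 must slide out the contiguous block of heads it owns before adding the bias to its attention scores. A self-contained sketch of that slicing on a flat [num_heads, length, length] tensor (std::vector stands in for StorageView; the helper name is hypothetical, not CTranslate2's API):

// Copy the contiguous block of heads owned by `rank` out of a full
// position-bias tensor laid out as [num_heads, length, length].
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<float> slice_bias_for_rank(const std::vector<float>& bias,
                                       std::size_t num_heads,
                                       std::size_t length,
                                       std::size_t rank,
                                       std::size_t world_size) {
  if (world_size == 0 || num_heads % world_size != 0)
    throw std::invalid_argument("num_heads must divide evenly across ranks");
  const std::size_t heads_per_rank = num_heads / world_size;  // like SAFE_DIVIDE
  const std::size_t per_head = length * length;
  const std::size_t offset = rank * heads_per_rank * per_head;
  return std::vector<float>(bias.begin() + offset,
                            bias.begin() + offset + heads_per_rank * per_head);
}

// e.g. with 16 heads on 4 GPUs, rank 1 receives heads 4..7, matching the
// Slide over dimension 0 in the diff above.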
