
Commit 99804a2

Fix MPI-related issue in GPT model parallel (#1894)

* Add MPI include dir in CMake. Support more MPI environment vars. Fix HOST conflict between MPI and Paddle.
* Fix pointer conversion in gpt_op.
1 parent 01d2c2a commit 99804a2
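
The "Support more MPI environment vars" item refers to the decoding.py change below, which now falls back through the world-size and rank variables exported by MPICH, OpenMPI, Intel MPI, and MVAPICH2. As an aside, a quick way to see which convention your launcher uses is to print the variables the new code inspects; this snippet is only an illustrative debugging aid, not part of the repository:

import os

# Variables checked by the new _env2int fallback added in decoding.py.
SIZE_VARS = ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE",
             "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"]
RANK_VARS = ["MPI_LOCALRANKID", "OMPI_COMM_WORLD_RANK", "PMI_RANK",
             "MV2_COMM_WORLD_RANK", "RANK"]

for name in SIZE_VARS + RANK_VARS:
    value = os.environ.get(name)
    if value is not None:
        print(f"{name}={value}")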

File tree

3 files changed: +44 -24 lines

paddlenlp/ops/CMakeLists.txt
paddlenlp/ops/faster_transformer/src/fusion_gpt_op.cu
paddlenlp/ops/faster_transformer/transformer/decoding.py

paddlenlp/ops/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,7 @@ option(WITH_PARALLEL "Compile with model parallel for GPT"
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 if(WITH_PARALLEL)
   # https://cmake.org/cmake/help/latest/module/FindMPI.html#variables-for-locating-mpi
+  # https://github.com/Kitware/CMake/blob/master/Modules/FindMPI.cmake
   find_package(MPI REQUIRED)
   find_package(NCCL REQUIRED)
   add_definitions(-DBUILD_GPT)
@@ -241,6 +242,12 @@ set(COMMON_LIB_DIRS
   ${CUDA_PATH}/lib64
 )
 
+if(WITH_PARALLEL)
+  list(APPEND COMMON_HEADER_DIRS
+       ${NCCL_INCLUDE_PATH}
+       ${MPI_INCLUDE_PATH})
+endif()
+
 set(THIRD_PATH "third-party")
 set(THIRD_PARTY_NAME "fastertransformer")

paddlenlp/ops/faster_transformer/src/fusion_gpt_op.cu

Lines changed: 13 additions & 11 deletions
@@ -9,12 +9,14 @@
 #include <vector>
 
 // TODO(guosheng): `HOST` conflict exists in float.h of paddle and mpi.h of mpi
+#include "fusion_gpt_op.h"
+#include "pd_traits.h"
+#ifdef HOST
+#undef HOST
+#endif
 #include "fastertransformer/cuda/cub/cub.cuh"
 #include "fastertransformer/gpt.h"
-#include "fastertransformer/open_decoder.h"
 #include "fastertransformer/utils/common.h"
-#include "fusion_gpt_op.h"
-#include "pd_traits.h"
 
 #ifdef BUILD_GPT // consistent with FasterTransformer
 #include <map>
@@ -274,14 +276,14 @@ std::vector<paddle::Tensor> gpt2_kernel(
   const int hidden_unit = size_per_head * n_head;
 
 #ifdef BUILD_GPT
-  auto* model_para_desc =
-      ModelParaDescFactory::CreateModelParaDesc(n_head,
-                                                size_per_head,
-                                                num_layer,
-                                                tensor_para_size,
-                                                layer_para_size,
-                                                layer_para_batch_size,
-                                                word_emb.data<data_t_>());
+  auto* model_para_desc = ModelParaDescFactory::CreateModelParaDesc(
+      n_head,
+      size_per_head,
+      num_layer,
+      tensor_para_size,
+      layer_para_size,
+      layer_para_batch_size,
+      const_cast<data_t_*>(word_emb.data<data_t_>()));
   auto& tensor_parallel_param = model_para_desc->tensor_parallel_param;
   auto& layer_parallel_param = model_para_desc->layer_parallel_param;
   auto seed = model_para_desc->dist(model_para_desc->gen);

paddlenlp/ops/faster_transformer/transformer/decoding.py

Lines changed: 24 additions & 13 deletions
@@ -1175,29 +1175,40 @@ def __init__(self,
                  tensor_para_size=None,
                  layer_para_size=None,
                  layer_para_batch_size=1):
-        # Maybe we should import mpi4py later.
-        self.word_size = int(
-            os.environ.get(
-                "MPI_LOCALNRANKS",  # MPICH
-                os.environ.get("OMPI_COMM_WORLD_SIZE", 1)))  # OpenMPI
-        self.rank = int(
-            os.environ.get(
-                "MPI_LOCALRANKID",  # MPICH
-                os.environ.get("OMPI_COMM_WORLD_RANK", 0)))  # OpenMPI
-        if tensor_para_size is None: tensor_para_size = 1
-        if layer_para_size is None:
-            layer_para_size = self.word_size // tensor_para_size
+        self.world_size = self._env2int(
+            [  # MPICH, OpenMPI, IMPI
+                "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE",
+                "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"
+            ],
+            1)
+        self.rank = self._env2int(
+            [  # MPICH, OpenMPI, IMPI
+                "MPI_LOCALRANKID", "OMPI_COMM_WORLD_RANK", "PMI_RANK",
+                "MV2_COMM_WORLD_RANK", "RANK"
+            ],
+            0)
+        if layer_para_size is None: layer_para_size = 1
+        if tensor_para_size is None:
+            tensor_para_size = self.world_size // layer_para_size
         self.no_para = tensor_para_size == 1 and layer_para_size == 1
         self.tensor_para_size = tensor_para_size
         self.layer_para_size = layer_para_size
         self.layer_para_batch_size = layer_para_batch_size
 
-        assert self.word_size == tensor_para_size * layer_para_size, (
+        assert self.world_size == tensor_para_size * layer_para_size, (
             "tensor_para_size * layer_para_size must be equal to world_size.")
         self.tensor_para_rank = self.rank % self.tensor_para_size
         self.layer_para_rank = self.rank // self.tensor_para_size
         self.is_partial_model = False
 
+    @staticmethod
+    def _env2int(env_list, default=-1):
+        for e in env_list:
+            val = int(os.environ.get(e, -1))
+            if val >= 0:
+                return val
+        return default
+
     def is_last_group(self):
         r"""
         For layer parallel, only the process corresponding to the last layer
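
To make the decoding.py change concrete, here is a minimal standalone sketch of the same detection logic: walk the candidate environment variables in order, fall back to a default when none is set, then derive the tensor- and layer-parallel ranks exactly as the constructor above does. The names env2int and detect_parallel_config are illustrative only and are not part of PaddleNLP's API.

import os


def env2int(env_list, default=-1):
    # Return the first non-negative integer found among the listed
    # environment variables, else the default.
    for name in env_list:
        val = int(os.environ.get(name, -1))
        if val >= 0:
            return val
    return default


def detect_parallel_config(tensor_para_size=None, layer_para_size=None):
    # World size and rank, trying MPICH, OpenMPI, Intel MPI, MVAPICH2, and
    # generic WORLD_SIZE/RANK variables in turn (same lists as the commit).
    world_size = env2int([
        "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE",
        "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"
    ], 1)
    rank = env2int([
        "MPI_LOCALRANKID", "OMPI_COMM_WORLD_RANK", "PMI_RANK",
        "MV2_COMM_WORLD_RANK", "RANK"
    ], 0)
    # Defaults mirror the new behavior: layer_para_size defaults to 1 and
    # tensor_para_size absorbs the remaining processes.
    if layer_para_size is None:
        layer_para_size = 1
    if tensor_para_size is None:
        tensor_para_size = world_size // layer_para_size
    assert world_size == tensor_para_size * layer_para_size, (
        "tensor_para_size * layer_para_size must be equal to world_size.")
    tensor_para_rank = rank % tensor_para_size
    layer_para_rank = rank // tensor_para_size
    return world_size, rank, tensor_para_rank, layer_para_rank

For example, with world_size = 4, tensor_para_size = 2, and layer_para_size = 2, rank 3 maps to tensor_para_rank = 1 and layer_para_rank = 1: consecutive ranks form a tensor-parallel group, and those groups are stacked across layer-parallel stages.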
