Skip to content

Commit 58dd9c8

Browse files
add FlagCX as optional communication library when using iluvatar gpus (#1862)
1 parent b69382b commit 58dd9c8

File tree

6 files changed

+491
-8
lines changed

6 files changed

+491
-8
lines changed

backends/iluvatar_gpu/CMakeLists.txt

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ include(external/eigen)
3333
include(external/xxhash)
3434
include(external/zlib)
3535
include(external/protobuf)
36+
if(WITH_FLAGCX)
37+
add_definitions("-DPADDLE_WITH_FLAGCX")
38+
include(external/flagcx)
39+
endif()
3640

3741
set(PLUGIN_VERSION ${PADDLE_VERSION})
3842
set(PROTO_FILE "${PADDLE_SOURCE_DIR}/paddle/phi/core/external_error.proto")
@@ -66,8 +70,11 @@ target_include_directories(external_error_proto
6670
target_link_libraries(external_error_proto PUBLIC protobuf)
6771
set_target_properties(external_error_proto PROPERTIES POSITION_INDEPENDENT_CODE
6872
ON)
69-
70-
add_custom_target(external_deps DEPENDS eigen3 zlib protobuf)
73+
if(WITH_FLAGCX)
74+
add_custom_target(external_deps DEPENDS eigen3 zlib protobuf flagcx)
75+
else()
76+
add_custom_target(external_deps DEPENDS eigen3 zlib protobuf)
77+
endif()
7178

7279
if(WITH_COREX)
7380
add_definitions(-DPADDLE_WITH_COREX)
@@ -247,6 +254,10 @@ file(
247254
kernels/ernie_core/*.cc
248255
kernels/gpudnn/*.cu)
249256

257+
if(WITH_FLAGCX)
258+
list(APPEND CC_SRCS runtime/runtime_flagcx.cc)
259+
endif()
260+
250261
set(CUSTOM_DEVICE_SRCS ${CUDA_SRCS} ${CC_SRCS})
251262

252263
set_source_files_properties(${CUSTOM_DEVICE_SRCS} PROPERTIES LANGUAGE CUDA)
@@ -269,7 +280,9 @@ target_link_libraries(
269280
protobuf
270281
external_error_proto
271282
cuinfer
272-
nccl)
283+
nccl
284+
# replace nccl with ${FLAGCX_LIB} when compiling with FlagCX
285+
)
273286

274287
include_directories(BEFORE ${PADDLE_SOURCE_DIR})
275288

backends/iluvatar_gpu/build_paddle.sh

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#!/bin/bash
22

33
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
4-
#
4+
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
77
# You may obtain a copy of the License at
8-
#
8+
#
99
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
10+
#
1111
# Unless required by applicable law or agreed to in writing, software
1212
# distributed under the License is distributed on an "AS IS" BASIS,
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -26,6 +26,15 @@ export CMAKE_CUDA_ARCHITECTURES=${COREX_ARCH}
2626
CURRENT_DIR=$(pwd)
2727
PADDLE_SOURCE_DIR="${CURRENT_DIR}/../../Paddle"
2828
PATCH_FILE="${CURRENT_DIR}/patches/paddle-corex.patch"
29+
# Set BUILD_WITH_FLAGCX to 1 to use FlagCX as the communication backend.
30+
BUILD_WITH_FLAGCX=0
31+
FLAGCX_ROOT="/workspace/FlagCX"
32+
33+
if [ "$BUILD_WITH_FLAGCX" == "1" ]; then
34+
WITH_FLAGCX="ON"
35+
else
36+
WITH_FLAGCX="OFF"
37+
fi
2938

3039
bash clean_paddle.sh
3140

@@ -51,9 +60,10 @@ if [[ ! -d "build" ]]; then
5160
fi
5261
pushd build
5362

54-
cmake -DPY_VERSION=${PYTHON_VERSION} -DWITH_COREX=ON \
55-
-DWITH_DISTRIBUTE=ON -DWITH_NCCL=ON -DWITH_RCCL=OFF -DCMAKE_BUILD_TYPE=Release \
63+
cmake -DPY_VERSION=${PYTHON_VERSION} -DWITH_COREX=ON -DPADDLE_SOURCE_DIR=${PADDLE_SOURCE_DIR} \
64+
-DWITH_DISTRIBUTE=ON -DWITH_NCCL=ON -DWITH_FLAGCX=${WITH_FLAGCX} -DWITH_RCCL=OFF -DCMAKE_BUILD_TYPE=Release \
5665
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DON_INFER=ON -DCOREX_VERSION=${COREX_VERSION} -DCOREX_ARCH=${COREX_ARCH} \
66+
-DFLAGCX_ROOT=${FLAGCX_ROOT} \
5767
-DCMAKE_CXX_FLAGS='-Wno-error=pessimizing-move -Wno-error=deprecated-copy -Wno-error=init-list-lifetime' \
5868
-DCMAKE_CUDA_FLAGS='-Xclang -fcuda-allow-variadic-functions -mllvm --skip-double' \
5969
-DWITH_ARM=OFF -DWITH_DGC=OFF .. 2>&1 | tee compile.log
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
# flagcx.cmake
#
# Wires an existing FlagCX checkout (FLAGCX_ROOT) into the build.
#
# Inputs:
#   WITH_FLAGCX        - when false this file returns immediately.
#   FLAGCX_ROOT        - root of the FlagCX checkout.
#   PADDLE_SOURCE_DIR  - used to derive the third_party mirror location.
#
# Results:
#   FLAGCX_INCLUDE_DIR - directory that contains flagcx.h.
#   FLAGCX_LIB         - full path of the FlagCX library found in
#                        ${FLAGCX_ROOT}/build/lib.
#   flagcx             - INTERFACE target carrying the include dir and library.
#
# NOTE: this file does not compile FlagCX; the library must already exist in
# FLAGCX_LIB_DIR before configuring.
if(NOT WITH_FLAGCX)
  return()
endif()

# Fail early with a clear message instead of a generic copy failure below.
if(NOT FLAGCX_ROOT OR NOT IS_DIRECTORY "${FLAGCX_ROOT}")
  message(
    FATAL_ERROR
      "WITH_FLAGCX is ON but FLAGCX_ROOT ('${FLAGCX_ROOT}') is not a directory")
endif()

set(FLAGCX_SOURCE_DIR "${FLAGCX_ROOT}")
set(FLAGCX_LIB_DIR "${FLAGCX_SOURCE_DIR}/build/lib")
set(FLAGCX_BINARY_DIR "${PADDLE_SOURCE_DIR}/build/third_party/flagcx")
set(THIRD_PARTY_DIR "${PADDLE_SOURCE_DIR}/build/third_party")

# Refresh the mirror of the FlagCX tree. The copy goes to the same directory
# that was just removed, so a stale mirror can never survive a reconfigure.
file(REMOVE_RECURSE "${FLAGCX_BINARY_DIR}")
message(STATUS "removed old flagcx dir")
message(STATUS "Copying third-party source to build directory")
execute_process(
  COMMAND "${CMAKE_COMMAND}" -E copy_directory "${FLAGCX_SOURCE_DIR}"
          "${FLAGCX_BINARY_DIR}"
  RESULT_VARIABLE COPY_RESULT)

if(NOT COPY_RESULT EQUAL 0)
  message(FATAL_ERROR "Failed to copy third-party source to build directory")
endif()

# No build step is executed here; only the header and library are located.
message(STATUS "Building third-party library with its Makefile")

find_path(
  FLAGCX_INCLUDE_DIR flagcx.h
  PATHS "${FLAGCX_SOURCE_DIR}/flagcx/include"
  NO_DEFAULT_PATH)

if(NOT FLAGCX_INCLUDE_DIR)
  message(
    FATAL_ERROR "flagcx.h not found under ${FLAGCX_SOURCE_DIR}/flagcx/include")
endif()

message(STATUS "FLAGCX_INCLUDE_DIR is ${FLAGCX_INCLUDE_DIR}")
# Kept directory-scoped so existing targets that include flagcx.h keep working.
include_directories(SYSTEM "${FLAGCX_INCLUDE_DIR}")

add_library(flagcx INTERFACE)
find_library(
  FLAGCX_LIB
  NAMES flagcx libflagcx
  PATHS "${FLAGCX_LIB_DIR}"
  DOC "My custom library")

if(NOT FLAGCX_LIB)
  message(
    FATAL_ERROR
      "FlagCX library not found in ${FLAGCX_LIB_DIR}; build FlagCX first")
endif()

# FLAGCX_LIB is a file path, not a target, so it cannot be used with
# add_dependencies(); attach it to the interface target as usage requirements.
target_include_directories(flagcx SYSTEM INTERFACE "${FLAGCX_INCLUDE_DIR}")
target_link_libraries(flagcx INTERFACE "${FLAGCX_LIB}")
message(STATUS "FLAGCX_LIB is ${FLAGCX_LIB}")

backends/iluvatar_gpu/runtime/runtime.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include <errno.h>
1616
#include <fcntl.h>
1717
#include <nccl.h>
18+
#if defined(PADDLE_WITH_FLAGCX)
19+
#include "runtime_flagcx.h" // NOLINT
20+
#endif
1821
#include <semaphore.h>
1922
#include <sys/types.h>
2023
#include <sys/wait.h>
@@ -1023,6 +1026,9 @@ void InitPlugin(CustomRuntimeParams *params) {
10231026
0,
10241027
sizeof(C_DeviceInterface));
10251028

1029+
#if defined(PADDLE_WITH_FLAGCX)
1030+
flagcxHandleInit(&flagcx_handler);
1031+
#endif
10261032
params->interface->get_compute_capability = GetComputeCapability;
10271033
params->interface->get_runtime_version = GetRuntimeVersion;
10281034
params->interface->get_driver_version = GetDriverVersion;
@@ -1073,6 +1079,22 @@ void InitPlugin(CustomRuntimeParams *params) {
10731079
params->interface->init_eigen_device = InitEigenDevice;
10741080
params->interface->destroy_eigen_device = DestroyEigenDevice;
10751081

1082+
#if defined(PADDLE_WITH_FLAGCX)
1083+
params->interface->xccl_all_gather = XcclFlagcxAllGather;
1084+
params->interface->xccl_all_reduce = XcclFlagcxAllReduce;
1085+
params->interface->xccl_broadcast = XcclFlagcxBroadcast;
1086+
params->interface->xccl_comm_init_rank = XcclFlagcxCommInitRank;
1087+
params->interface->xccl_destroy_comm = XcclFlagcxDestroyComm;
1088+
params->interface->xccl_get_unique_id = XcclFlagcxGetUniqueId;
1089+
params->interface->xccl_get_unique_id_size = XcclFlagcxGetUniqueIdSize;
1090+
params->interface->xccl_group_end = XcclFlagcxGroupEnd;
1091+
params->interface->xccl_group_start = XcclFlagcxGroupStart;
1092+
params->interface->xccl_recv = XcclFlagcxRecv;
1093+
params->interface->xccl_reduce = XcclFlagcxReduce;
1094+
params->interface->xccl_reduce_scatter = XcclFlagcxReduceScatter;
1095+
params->interface->xccl_send = XcclFlagcxSend;
1096+
params->interface->xccl_all_to_all = XcclFlagcxAllToAll;
1097+
#else
10761098
params->interface->xccl_all_gather = XcclAllGather;
10771099
params->interface->xccl_all_reduce = XcclAllReduce;
10781100
params->interface->xccl_broadcast = XcclBroadcast;
@@ -1086,6 +1108,7 @@ void InitPlugin(CustomRuntimeParams *params) {
10861108
params->interface->xccl_reduce = XcclReduce;
10871109
params->interface->xccl_reduce_scatter = XcclReduceScatter;
10881110
params->interface->xccl_send = XcclSend;
1111+
#endif
10891112

10901113
params->interface->profiler_collect_trace_data = nullptr;
10911114
params->interface->profiler_initialize = nullptr;

0 commit comments

Comments
 (0)