Skip to content

Commit 0a4c252

Browse files
committed
include/hipify refactor
1 parent ff54a6a commit 0a4c252

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

transformer_engine/common/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ else()
238238
IGNORES "*/amd_detail/*"
239239
IGNORES "*/aotriton/*"
240240
IGNORES "*/ck_fused_attn/*"
241+
IGNORES "*/rocshmem_api/*"
241242
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
242243
NO_MATH_REPLACE
243244
)
@@ -385,6 +386,9 @@ target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAM
385386
else()
386387
option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF)
387388
if (NVTE_ENABLE_ROCSHMEM)
389+
find_package(MPI REQUIRED)
390+
target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
391+
target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
388392
add_subdirectory(rocshmem_api)
389393
if(DEFINED ENV{ROCSHMEM_HOME})
390394
set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation")

transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88
#ifdef NVTE_ENABLE_ROCSHMEM
99
#include <mpi.h>
1010
#include <rocshmem_api/rocshmem_waitkernel.hpp>
11-
#endif
12-
13-
#include <cuda.h>
14-
#include <cuda_fp8.h>
11+
#include <hip/hip_runtime.h>
1512
#include <torch/cuda.h>
1613
#include <torch/extension.h>
14+
#endif
1715

1816
namespace transformer_engine::pytorch {
1917

0 commit comments

Comments
 (0)