
Commit f602b36

Gfx950 bringup changes (#172)
* Gfx950 bringup changes
* Fix pre-ROCm6.3 build
1 parent 1087b26 commit f602b36

8 files changed: 91 additions & 55 deletions


ci/core.sh

Lines changed: 12 additions & 6 deletions

@@ -24,19 +24,25 @@ if [ $rc -ne 0 ]; then
     exit $rc
 fi

-echo ===== Run non GEMM tests =====
-ctest --test-dir build -j4 --output-on-failure -E "OperatorTest/GEMMTestSuite"
-test $? -eq 0 || test_run_error "non-GEMM"
+check_test_filter "nongemm"
+if [ $? -eq 0 ]; then
+    echo ===== Run non GEMM tests =====
+    ctest --test-dir build -j4 --output-on-failure -E "OperatorTest/GEMMTestSuite"
+    test $? -eq 0 || test_run_error "non-GEMM"
+fi

 for _gemm in hipblaslt rocblas; do
     configure_gemm_env $_gemm || continue
     _exclude=""
     if [ $_gemm = "hipblaslt" ]; then
         _exclude="-E Test(.*bf16/.*X.X1|.*fp8.*fp16/.*X1X0|.*fp8.*X.X1|.*fp8/|.*bf8/)"
     fi
-    echo ===== Run GEMM $_gemm tests =====
-    ctest --test-dir build -j4 --output-on-failure -R "OperatorTest/GEMMTestSuite" $_exclude
-    test $? -eq 0 || test_run_error "GEMM $_gemm"
+    check_test_filter $_gemm
+    if [ $? -eq 0 ]; then
+        echo ===== Run GEMM $_gemm tests =====
+        ctest --test-dir build -j4 --output-on-failure -R "OperatorTest/GEMMTestSuite" $_exclude
+        test $? -eq 0 || test_run_error "GEMM $_gemm"
+    fi
 done

 return_run_results

tests/cpp/operator/test_cublaslt_gemm.cu

Lines changed: 19 additions & 1 deletion

@@ -155,7 +155,13 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   }
 #endif

-  Tensor Workspace({ 33554432 }, DType::kByte);
+  size_t workspace_size = 33554432;
+#ifdef __HIP_PLATFORM_AMD__
+  if (prop.major == 9 && prop.minor == 5) {
+    workspace_size = 67108864;
+  }
+#endif
+  Tensor Workspace({ workspace_size }, DType::kByte);

   //perform the gemm in GPU
   nvte_cublas_gemm(A.data(),

@@ -212,6 +218,18 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   if (dtype == DType::kFloat32) {
     atol = 1e-5;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  if (prop.major == 9 && prop.minor == 5)
+  {
+    // relax for certain gemm with hipblaslt
+    if (!isFp8Type(dtype) && (isFp8Type(atype) or isFp8Type(btype))) {
+      atol = 5e-4;
+      rtol = 5e-3;
+    } else if (dtype == DType::kFloat32) {
+      rtol = 1e-5;
+    }
+  }
+#endif
   compareResults("D", D, ref_D.get(), atol, rtol);

   if(use_gelu){
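
Note on the two hunks above: prop is a hipDeviceProp_t that the test populates elsewhere in the file, and major 9 / minor 5 identifies the gfx950 architecture, for which the workspace is doubled from 32 MiB to 64 MiB and the comparison tolerances are relaxed. A minimal standalone sketch of that detection, assuming only the HIP runtime is available (the variable names are illustrative, not taken from the test):

    #include <hip/hip_runtime.h>
    #include <cstdio>

    int main() {
      // Query device 0; prop.major/prop.minor encode the compute
      // architecture (9.5 corresponds to gfx950).
      hipDeviceProp_t prop;
      if (hipGetDeviceProperties(&prop, 0) != hipSuccess) {
        std::fprintf(stderr, "failed to query device properties\n");
        return 1;
      }
      // Mirror the workspace sizing from the diff: 64 MiB on gfx950,
      // 32 MiB elsewhere (sizes in bytes).
      size_t workspace_size = 33554432;
      if (prop.major == 9 && prop.minor == 5) {
        workspace_size = 67108864;
      }
      std::printf("arch %d.%d -> workspace %zu bytes\n",
                  prop.major, prop.minor, workspace_size);
      return 0;
    }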

tests/pytorch/test_numerics.py

Lines changed: 3 additions & 2 deletions

@@ -1972,8 +1972,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     if IS_HIP_EXTENSION:
         if use_hipblaslt():
             tols = dtype_tols(dtype)
-            if dtype in (torch.float16, torch.bfloat16) and is_mi308():
-                # mi308 hipblaslt precision issue
+            if dtype in (torch.float16, torch.bfloat16):
+                # On some GPUs hipblaslt results for SBHD and BSHD differ,
+                # which lowers the precision of the final result
                 tols["atol"] = 2e-3
         _, use_aotriton, use_ck = rocm_attn_backend()
         if use_aotriton and not use_ck:

transformer_engine/common/CMakeLists.txt

Lines changed: 51 additions & 42 deletions

@@ -255,7 +255,7 @@ else()
   set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
   set(__AOTRITON_SUFFIX "_TEprivate")
   if(NOT DEFINED AOTRITON_PATH)
-    # # Install aotriton fused attn
+    # Install aotriton fused attn
    if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
      set(AOTRITON_NOIMAGE_MODE OFF)
    else()

@@ -271,51 +271,60 @@
      foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES)
        set(key ${X})
        string(APPEND key "_key")
-       string(APPEND aotriton_target_gpus ${${key}})
-       string(APPEND aotriton_target_gpus "|")
+       set(gpu ${${key}})
+       if (gpu)
+         string(APPEND aotriton_target_gpus "${gpu}|")
+       else()
+         message(WARNING "AOTriton building is not supported for ${X}")
+       endif()
      endforeach()
    endmacro()
    translate_arch_to_gpu_names(aotriton_target_gpus)
-   include(ExternalProject)
-   ExternalProject_Add(aotriton_external
-     SOURCE_DIR ../../3rdparty/aotriton
-     LIST_SEPARATOR |
-     CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-                -DTARGET_GPUS=${aotriton_target_gpus}
-                -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-                -DAOTRITON_NO_PYTHON=ON
-                -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-                -DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
-     BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
-   )
-   add_library(aotriton INTERFACE)
-   add_dependencies(aotriton aotriton_external)
-   target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
-   target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
-   if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
-     set(__AOTRITON_VER "0.8.2b")
-     set(__AOTRITON_SHA256 "66445e6b0209b9f4080743b839cc9d424054dc5c8d07363f9f27f109231c324a")
-     string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
-                                  "${__AOTRITON_VER}/aotriton-"
-                                  "${__AOTRITON_VER}-manylinux_2_28"
-                                  "_x86_64-rocm6.2"
-                                  "-shared.tar.gz")
-     ExternalProject_Add(aotriton_images
-       URL "${__AOTRITON_URL}"
-       URL_HASH SHA256=${__AOTRITON_SHA256}
-       SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
-       CONFIGURE_COMMAND ""
-       BUILD_COMMAND ""
-       INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
-         "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images"
-         "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
-     add_dependencies(aotriton aotriton_images)
+   if (NOT aotriton_target_gpus)
+     set(USE_FUSED_ATTN_AOTRITON FALSE)
+     message(WARNING "Disable AOTriton building because no supported GPU targets found")
+   else()
+     include(ExternalProject)
+     ExternalProject_Add(aotriton_external
+       SOURCE_DIR ../../3rdparty/aotriton
+       LIST_SEPARATOR |
+       CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
+                  -DTARGET_GPUS=${aotriton_target_gpus}
+                  -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+                  -DAOTRITON_NO_PYTHON=ON
+                  -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
+                  -DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
+       BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
+     )
+     add_library(aotriton INTERFACE)
+     add_dependencies(aotriton aotriton_external)
+     target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
+     target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
+     if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
+       set(__AOTRITON_VER "0.8.2b")
+       set(__AOTRITON_SHA256 "66445e6b0209b9f4080743b839cc9d424054dc5c8d07363f9f27f109231c324a")
+       string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
+                                    "${__AOTRITON_VER}/aotriton-"
+                                    "${__AOTRITON_VER}-manylinux_2_28"
+                                    "_x86_64-rocm6.2"
+                                    "-shared.tar.gz")
+       ExternalProject_Add(aotriton_images
+         URL "${__AOTRITON_URL}"
+         URL_HASH SHA256=${__AOTRITON_SHA256}
+         SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
+         CONFIGURE_COMMAND ""
+         BUILD_COMMAND ""
+         INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
+           "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images"
+           "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
+       add_dependencies(aotriton aotriton_images)
+     endif()
+     install(DIRECTORY
+       ${__AOTRITON_INSTALL_DIR}/lib
+       DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
+       PATTERN "cmake" EXCLUDE
+       PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)
    endif()
-   install(DIRECTORY
-     ${__AOTRITON_INSTALL_DIR}/lib
-     DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
-     PATTERN "cmake" EXCLUDE
-     PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)
  else()
    # Use aotriton built during initial TE building/installation
    # When only need rebuild TE library itself

transformer_engine/common/amd_detail/hip_float8.h

Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ struct te_hip_fp8_e4m3 {

   __host__ __device__ operator float() const { return data.operator float(); }

-  __host__ __device__ te_hip_fp8_e4m3(const float& v) { data = v;}
+  __host__ __device__ te_hip_fp8_e4m3(const float& v): data(v) {}
 };
 static_assert(sizeof(te_hip_fp8_e4m3) == 1, "Size mismatch");
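
The replacement constructor uses a member initializer list, so data is constructed directly from v rather than default-constructed and then assigned in the constructor body. A self-contained sketch of the difference, using a hypothetical wrapper type rather than the actual FP8 member:

    #include <cstdio>

    struct Wrapped {
      float x;
      Wrapped() : x(0.0f) { std::puts("default-construct"); }
      explicit Wrapped(float v) : x(v) { std::puts("value-construct"); }
    };

    struct AssignInBody {
      Wrapped data;
      // data is default-constructed first, then overwritten in the body.
      AssignInBody(float v) { data = Wrapped(v); }
    };

    struct InitList {
      Wrapped data;
      // data is constructed directly from v; no default construction.
      InitList(float v) : data(v) {}
    };

    int main() {
      AssignInBody a(1.0f);  // prints: default-construct, value-construct
      InitList b(2.0f);      // prints: value-construct
      return 0;
    }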

transformer_engine/common/recipe/delayed_scaling.cu

Lines changed: 3 additions & 1 deletion

@@ -36,8 +36,10 @@ inline float fp8_dtype_max(DType dtype) {
     case DType::kFloat8E4M3:
 #ifndef __HIP_PLATFORM_AMD__
       return 448;
-#else
+#elif HIP_VERSION >= 60300000
       return te_fp8_fnuz() ? 240 : 448;
+#else
+      return 240;  // only the FNUZ format is supported before ROCm 6.3
 #endif
     case DType::kFloat8E5M2:
       return 57344;
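
For context on the constants in this hunk: OCP E4M3 has a maximum finite value of 448, while the FNUZ E4M3 variant (the only FP8 flavor available before ROCm 6.3) tops out at 240 because of its different exponent bias; E5M2 reaches 57344 in both families. A small illustrative helper, not code from the repository:

    #include <cstdio>

    // Maximum finite values of the FP8 formats referenced in the diff.
    float fp8_e4m3_max(bool fnuz) { return fnuz ? 240.0f : 448.0f; }
    float fp8_e5m2_max() { return 57344.0f; }

    int main() {
      std::printf("E4M3 OCP max:  %g\n", fp8_e4m3_max(false));  // 448
      std::printf("E4M3 FNUZ max: %g\n", fp8_e4m3_max(true));   // 240
      std::printf("E5M2 max:      %g\n", fp8_e5m2_max());       // 57344
      return 0;
    }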

transformer_engine/common/util/system.cpp

Lines changed: 1 addition & 1 deletion

@@ -91,6 +91,6 @@ extern "C" bool nvte_uses_fp8_fnuz()
 #if HIP_VERSION >= 60300000
   return te_fp8_fnuz();
 #endif
-  return true; // default to true for older versions that only support
+  return true; // default to true for compatibility with older versions
 }
 #endif
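
te_fp8_fnuz() is the runtime query that decides which FP8 family the device uses; its implementation lives elsewhere in the tree. Purely as an illustration of what such a check could look like (an assumption, not the repository's code), one could inspect the device architecture name, since gfx94x-class GPUs use the FNUZ formats while gfx950 uses OCP:

    #include <hip/hip_runtime.h>
    #include <cstring>
    #include <cstdio>

    // Hypothetical sketch, NOT the repository's te_fp8_fnuz():
    // returns true when the current device likely uses FNUZ FP8.
    bool uses_fp8_fnuz_sketch() {
      hipDeviceProp_t prop;
      if (hipGetDeviceProperties(&prop, 0) != hipSuccess) {
        return true;  // conservative default, mirroring nvte_uses_fp8_fnuz
      }
      // gcnArchName looks like "gfx942:sramecc+:xnack-".
      return std::strncmp(prop.gcnArchName, "gfx94", 5) == 0;
    }

    int main() {
      std::printf("FNUZ FP8: %s\n", uses_fp8_fnuz_sketch() ? "yes" : "no");
      return 0;
    }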

transformer_engine/pytorch/attention.py

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ def _get_supported_versions(version_min, version_max):
 _flash_attn_is_installed = False
 _flash_attn_version = PkgVersion("0")
 _flash_attn_version_required = PkgVersion("2.1.1")
-_flash_attn_max_version = PkgVersion("2.7.3")
+_flash_attn_max_version = PkgVersion("2.7.4.post1")
 _flash_attn_2_plus = False
 _flash_attn_2_1_plus = False
 _flash_attn_2_3_plus = False
