Skip to content

Commit 39daa14

Browse files
authored
Fix and add MXFP8 GEMM test failures (#326)
* Fix MXFP8 GEMM test * Fix uninitialized var in GEMM code * Add Dequantize+GEMM test to check MXFP8 scaling tensor layout
1 parent b08a1ed commit 39daa14

File tree

6 files changed

+239
-150
lines changed

6 files changed

+239
-150
lines changed

ci/core.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ fi
3131
check_test_filter "nongemm"
3232
if [ $? -eq 0 ]; then
3333
echo ===== Run non GEMM tests =====
34-
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "OperatorTest/GEMMTestSuite"
34+
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "GEMMTestSuite"
3535
test $? -eq 0 || test_run_error "non-GEMM"
3636
fi
3737

3838
check_test_filter "gemm"
3939
if [ $? -eq 0 ]; then
4040
echo ===== Run GEMM tests =====
41-
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -R "OperatorTest/GEMMTestSuite"
41+
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -R "GEMMTestSuite"
4242
test $? -eq 0 || test_run_error "GEMM"
4343
fi
4444

tests/cpp/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ else()
6464
project(transformer_engine_tests LANGUAGES HIP CXX)
6565
# Ask hcc to generate device code during compilation so we can use
6666
# host linker to link.
67-
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
68-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}")
67+
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result")
6968
endif()
7069

7170
add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)

0 commit comments

Comments
 (0)