Skip to content

Commit 7e0ee2b

Browse files
authored
Add CMake option to enable saturation checker for ConvSymKernelAvx2 (microsoft#24220)
### Description <!-- Describe your changes. --> This PR adds a new CMake option: onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER. When enabled, this option activates a saturation checker for the VPMADDUBSW instruction used in the ConvSymKernelAvx2 path. The checker works by calling a helper function before each VPMADDUBSW instruction. This function simulates the computation using C++ and intrinsics with higher-precision types (int32_t) to detect whether the result exceeds the bounds of int16_t (i.e., greater than INT16_MAX or less than INT16_MIN). By default, the checker logs a warning only once per inference session. However, the logic can be easily extended to print more frequently if needed. Developers can also reuse this pattern to implement similar saturation checks for other instructions. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> On some models running with AVX2 (instead of AVX-VNNI), we've observed accuracy degradation due to saturation in vectorized instructions. This saturation checker provides a way to debug and detect those cases by reporting potential overflow in intermediate computations.
1 parent edb7a2a commit 7e0ee2b

File tree

9 files changed

+298
-0
lines changed

9 files changed

+298
-0
lines changed

cmake/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ option(onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE "Use a custom SDL Ru
6565
option(onnxruntime_ENABLE_PYTHON "Enable python bindings" OFF)
6666
# Enabling it may cause an LNK1169 linker error
6767
option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF)
68+
option(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER "Experimental: Enable ConvSymKernelAvx2 assembly saturation checker in build" OFF)
6869
option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
6970
# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through
7071
# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical

cmake/onnxruntime_mlas.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
4646
${MLAS_SRC_DIR}/rotary_embedding.h
4747
${MLAS_SRC_DIR}/rotary_embedding.cpp
4848
${MLAS_SRC_DIR}/softmax.h
49+
${MLAS_SRC_DIR}/saturation_check.cpp
4950
)
5051

5152
target_sources(onnxruntime_mlas PRIVATE
@@ -239,6 +240,10 @@ function(setup_mlas_source_for_windows)
239240
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
240241
)
241242

243+
if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER)
244+
set_source_files_properties(${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm PROPERTIES COMPILE_FLAGS "-DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER")
245+
endif()
246+
242247
if(MSVC_VERSION GREATER_EQUAL 1933)
243248
target_sources(onnxruntime_mlas PRIVATE
244249
${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
@@ -637,6 +642,7 @@ else()
637642
${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
638643
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
639644
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
645+
${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp
640646
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
641647
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
642648
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
@@ -716,6 +722,10 @@ endif()
716722
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
717723
endif()
718724

725+
if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER)
726+
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER")
727+
endif()
728+
719729
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
720730
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
721731
set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64")

onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,87 @@ INCLUDE ConvSymKernelCommon.inc
2323
INCLUDE AssembleAvxVnni.inc
2424
.list
2525

26+
extern CheckSaturationForVPMADDUBSW:proc
27+
28+
CheckSaturation MACRO VecReg1Num, VecReg2Num

;
; Macro Description:
;
;   This macro calls the C helper CheckSaturationForVPMADDUBSW with pointers
;   to stack copies of two YMM registers, then restores every register it
;   touched so the surrounding kernel code continues unchanged.
;
; Arguments:
;
;   VecReg1Num - Supplies the number (0-15) of the YMM register holding the
;       first (unsigned) operand of the following VPMADDUBSW.
;
;   VecReg2Num - Supplies the number (0-15) of the YMM register holding the
;       second (signed) operand of the following VPMADDUBSW.
;
; Stack layout while the helper runs:
;
;   [rsp, rsp+32)       home (shadow) space owned by the callee
;   [rsp+32, rsp+544)   saved YMM0-YMM15, 32 bytes each
;
; NOTE(review): RFLAGS is not preserved across the call, and the call site
; is assumed to leave rsp 16-byte aligned after the adjustments below;
; confirm both against the kernels that expand this macro.
;

;
; Save the volatile general purpose registers (RAX, RCX, RDX, R8, R9, R10,
; R11). RSI and RDI are nonvolatile in the Windows x64 calling convention,
; so the callee preserves them.
;

        push_reg rax
        push_reg rcx
        push_reg rdx
        push_reg r8
        push_reg r9
        push_reg r10
        push_reg r11

;
; Reserve 32 bytes of home space for the callee plus 16 * 32 bytes for the
; YMM registers. The home space must sit at the lowest addresses: the callee
; is free to overwrite it, so no saved register may be stored there.
;

        sub     rsp, 32+512

;
; Save YMM registers (YMM0 to YMM15) above the home space.
;

        vmovdqu YMMWORD PTR [rsp+32], ymm0
        vmovdqu YMMWORD PTR [rsp+64], ymm1
        vmovdqu YMMWORD PTR [rsp+96], ymm2
        vmovdqu YMMWORD PTR [rsp+128], ymm3
        vmovdqu YMMWORD PTR [rsp+160], ymm4
        vmovdqu YMMWORD PTR [rsp+192], ymm5
        vmovdqu YMMWORD PTR [rsp+224], ymm6
        vmovdqu YMMWORD PTR [rsp+256], ymm7
        vmovdqu YMMWORD PTR [rsp+288], ymm8
        vmovdqu YMMWORD PTR [rsp+320], ymm9
        vmovdqu YMMWORD PTR [rsp+352], ymm10
        vmovdqu YMMWORD PTR [rsp+384], ymm11
        vmovdqu YMMWORD PTR [rsp+416], ymm12
        vmovdqu YMMWORD PTR [rsp+448], ymm13
        vmovdqu YMMWORD PTR [rsp+480], ymm14
        vmovdqu YMMWORD PTR [rsp+512], ymm15

;
; Pass the addresses of the saved operand registers as the two arguments.
;

        lea     rcx, [rsp+32+32*VecReg1Num]     ; first operand (unsigned)
        lea     rdx, [rsp+32+32*VecReg2Num]     ; second operand (signed)

        call    CheckSaturationForVPMADDUBSW

;
; Restore YMM registers.
;

        vmovdqu ymm0, YMMWORD PTR [rsp+32]
        vmovdqu ymm1, YMMWORD PTR [rsp+64]
        vmovdqu ymm2, YMMWORD PTR [rsp+96]
        vmovdqu ymm3, YMMWORD PTR [rsp+128]
        vmovdqu ymm4, YMMWORD PTR [rsp+160]
        vmovdqu ymm5, YMMWORD PTR [rsp+192]
        vmovdqu ymm6, YMMWORD PTR [rsp+224]
        vmovdqu ymm7, YMMWORD PTR [rsp+256]
        vmovdqu ymm8, YMMWORD PTR [rsp+288]
        vmovdqu ymm9, YMMWORD PTR [rsp+320]
        vmovdqu ymm10, YMMWORD PTR [rsp+352]
        vmovdqu ymm11, YMMWORD PTR [rsp+384]
        vmovdqu ymm12, YMMWORD PTR [rsp+416]
        vmovdqu ymm13, YMMWORD PTR [rsp+448]
        vmovdqu ymm14, YMMWORD PTR [rsp+480]
        vmovdqu ymm15, YMMWORD PTR [rsp+512]

        add     rsp, 32+512                     ; release home space and YMM save area

;
; Restore the volatile general purpose registers in reverse push order.
;

        pop     r11
        pop     r10
        pop     r9
        pop     r8
        pop     rdx
        pop     rcx
        pop     rax

        ENDM
106+
26107
;
27108
; Macro Description:
28109
;
@@ -50,9 +131,15 @@ INCLUDE AssembleAvxVnni.inc
50131

51132
MultiplyAccumulateRowAvx2 MACRO Vec1Reg, Vec2Reg
52133

134+
IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER
135+
CheckSaturation 2,0
136+
ENDIF
53137
vpmaddubsw ymm3,ymm2,ymm0
54138
vpmaddwd ymm3,ymm3,ymm12
55139
vpaddd Vec1Reg,Vec1Reg,ymm3
140+
IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER
141+
CheckSaturation 2,1
142+
ENDIF
56143
vpmaddubsw ymm2,ymm2,ymm1
57144
vpmaddwd ymm2,ymm2,ymm12
58145
vpaddd Vec2Reg,Vec2Reg,ymm2
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
saturation_check_avx2.cpp
10+
11+
Abstract:
12+
13+
This module implements logic to check saturation of the VPMADDUBSW
14+
instruction.
15+
16+
--*/
17+
18+
#include <immintrin.h>
19+
20+
#include <atomic>
21+
#include <iostream>
22+
23+
namespace onnxruntime
24+
{
25+
extern std::atomic<int> saturation_count;
26+
}
27+
28+
extern "C" void
29+
CheckSaturationForVPMADDUBSW(const __m256i* unsigned_ptr, const __m256i* signed_ptr)
30+
{
31+
// Load data from memory (unaligned load)
32+
__m256i unsigned_data = _mm256_loadu_si256(unsigned_ptr);
33+
__m256i signed_data = _mm256_loadu_si256(signed_ptr);
34+
35+
alignas(32) uint8_t unsigned_bytes[32]; // Unsigned input values
36+
alignas(32) int8_t signed_bytes[32]; // Signed input values
37+
38+
// Store the data into the byte arrays
39+
_mm256_store_si256(reinterpret_cast<__m256i*>(unsigned_bytes), unsigned_data);
40+
_mm256_store_si256(reinterpret_cast<__m256i*>(signed_bytes), signed_data);
41+
42+
bool saturation_detected = false;
43+
44+
// Iterate through the 16 pairs of 8-bit unsigned and signed values
45+
for (int i = 0; i < 16; ++i) {
46+
// Perform the VPMADDUBSW operation in higher precision (int32_t)
47+
int32_t computed_value =
48+
static_cast<int32_t>(signed_bytes[2 * i]) * static_cast<int32_t>(static_cast<uint32_t>(unsigned_bytes[2 * i])) +
49+
static_cast<int32_t>(signed_bytes[2 * i + 1]) * static_cast<int32_t>(static_cast<uint32_t>(unsigned_bytes[2 * i + 1]));
50+
51+
// If the computed value exceeds the 16-bit signed integer range, saturation occurred
52+
if (computed_value > INT16_MAX || computed_value < INT16_MIN) {
53+
saturation_detected = true;
54+
break;
55+
}
56+
}
57+
58+
// If saturation is detected, log a warning (only log once based on the atomic count)
59+
if (saturation_detected && ++onnxruntime::saturation_count < 2) {
60+
std::cerr << "Warning: saturation detected in VPMADDUBSW instruction." << std::endl;
61+
}
62+
}

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Module Name:
1818
#pragma once
1919

2020
#include <algorithm>
21+
#include <atomic>
2122
#include <cmath>
2223
#include <functional>
2324
#include <limits>
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
saturation_check.cpp
10+
11+
Abstract:
12+
13+
This module implements logic to check saturation of the VPMADDUBSW
14+
instruction.
15+
16+
--*/
17+
18+
#include "mlasi.h"
19+
20+
namespace onnxruntime
{

#if defined(MLAS_TARGET_AMD64)

// Number of saturation events recorded since the last reset; only defined
// for AMD64 targets.
std::atomic<int> saturation_count{0};

#endif

//
// Sets the saturation counter back to zero. On targets other than AMD64 the
// counter does not exist and this routine does nothing.
//
void
reset_saturation_count()
{
#if defined(MLAS_TARGET_AMD64)
    saturation_count.store(0);
#endif
}

}  // namespace onnxruntime

onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,91 @@ Abstract:
2323

2424
.intel_syntax noprefix
2525

26+
.extern CheckSaturationForVPMADDUBSW
27+
28+
.macro CheckSaturation VecReg1Num, VecReg2Num
//
// Calls CheckSaturationForVPMADDUBSW with pointers to stack copies of two
// YMM registers and restores every register saved below afterwards.
//
// VecReg1Num - number (0-15) of the YMM register passed as the first
//     (unsigned) operand.
// VecReg2Num - number (0-15) of the YMM register passed as the second
//     (signed) operand.
//
// The macro moves rsp by 9 * 8 + 512 = 584 bytes before the call.
// NOTE(review): the call presumably needs rsp 16-byte aligned, which depends
// on rsp at the expansion site - confirm against the kernel prologue.
// NOTE(review): RFLAGS is not saved; confirm no flags are live where this
// macro is expanded.
//
29+
30+
//
31+
// Save the caller-saved general purpose registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11)
32+
//
33+
34+
push rax
35+
push rcx
36+
push rdx
37+
push rsi
38+
push rdi
39+
push r8
40+
push r9
41+
push r10
42+
push r11
43+
44+
sub rsp, 512 # reserve space for 16 YMM registers (32 bytes each)
45+
46+
//
47+
// Save YMM registers (YMM0 to YMM15); register N is stored at [rsp+32*N]
48+
//
49+
50+
vmovdqu [rsp], ymm0
51+
vmovdqu [rsp+32], ymm1
52+
vmovdqu [rsp+64], ymm2
53+
vmovdqu [rsp+96], ymm3
54+
vmovdqu [rsp+128], ymm4
55+
vmovdqu [rsp+160], ymm5
56+
vmovdqu [rsp+192], ymm6
57+
vmovdqu [rsp+224], ymm7
58+
vmovdqu [rsp+256], ymm8
59+
vmovdqu [rsp+288], ymm9
60+
vmovdqu [rsp+320], ymm10
61+
vmovdqu [rsp+352], ymm11
62+
vmovdqu [rsp+384], ymm12
63+
vmovdqu [rsp+416], ymm13
64+
vmovdqu [rsp+448], ymm14
65+
vmovdqu [rsp+480], ymm15
66+
67+
lea rdi, [rsp+32*\VecReg1Num\()] # first argument: address of the saved unsigned operand
68+
lea rsi, [rsp+32*\VecReg2Num\()] # second argument: address of the saved signed operand
69+
70+
call CheckSaturationForVPMADDUBSW
71+
72+
//
73+
// Restore YMM registers from the save area written above
74+
//
75+
76+
vmovdqu ymm0, [rsp]
77+
vmovdqu ymm1, [rsp+32]
78+
vmovdqu ymm2, [rsp+64]
79+
vmovdqu ymm3, [rsp+96]
80+
vmovdqu ymm4, [rsp+128]
81+
vmovdqu ymm5, [rsp+160]
82+
vmovdqu ymm6, [rsp+192]
83+
vmovdqu ymm7, [rsp+224]
84+
vmovdqu ymm8, [rsp+256]
85+
vmovdqu ymm9, [rsp+288]
86+
vmovdqu ymm10, [rsp+320]
87+
vmovdqu ymm11, [rsp+352]
88+
vmovdqu ymm12, [rsp+384]
89+
vmovdqu ymm13, [rsp+416]
90+
vmovdqu ymm14, [rsp+448]
91+
vmovdqu ymm15, [rsp+480]
92+
93+
add rsp, 512 # release the YMM save area
94+
95+
//
96+
// Restore the caller-saved general purpose registers in reverse push order
97+
//
98+
99+
pop r11
100+
pop r10
101+
pop r9
102+
pop r8
103+
pop rdi
104+
pop rsi
105+
pop rdx
106+
pop rcx
107+
pop rax
108+
109+
.endm
110+
26111
/*++
27112

28113
Macro Description:
@@ -52,9 +137,15 @@ Implicit Arguments:
52137

53138
.macro MultiplyAccumulateRowAvx2 Vec1Reg, Vec2Reg
54139

140+
#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER)
141+
CheckSaturation 2,0
142+
#endif
55143
vpmaddubsw ymm3,ymm2,ymm0
56144
vpmaddwd ymm3,ymm3,ymm12
57145
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
146+
#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER)
147+
CheckSaturation 2,1
148+
#endif
58149
vpmaddubsw ymm2,ymm2,ymm1
59150
vpmaddwd ymm2,ymm2,ymm12
60151
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2

onnxruntime/core/session/inference_session.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,8 @@ Status InferenceSession::Run(const RunOptions& run_options,
28652865
}
28662866
#endif
28672867

2868+
reset_saturation_count();
2869+
28682870
// As N+1 inference runs (N for memory allocation and 1 for graph capturing)
28692871
// are needed before replaying the captured graph, here run N inference runs recursively until graph captured,
28702872
// so that users just need one session run to capture the graph.

onnxruntime/core/session/inference_session.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class IExecutionProvider;
5858
class IOBinding;
5959
struct Notification;
6060

61+
void reset_saturation_count();
62+
6163
#ifdef ENABLE_TRAINING
6264
struct PartialGraphExecutionState;
6365
using OrtValueCache = InlinedHashMap<std::string, OrtValue>;

0 commit comments

Comments
 (0)