// Compilation: nvcc -arch=sm_100a -std=c++17 test_nvfp4_rounding.cu -o test_nvfp4_rounding
// For Godbolt: Add flags: -arch=sm_100a -std=c++17
// Note: the cvt.*.e2m1x2 conversion instruction requires a Blackwell-class target (e.g. sm_100a);
//       it is not available on earlier architectures such as sm_90

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <cstring>

// Convert float to bfloat16 and back (round-to-nearest-even, matching hardware bf16 conversion)
float to_bfloat16_and_back(float val) {
    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));  // avoid strict-aliasing issues
    bits += 0x7FFF + ((bits >> 16) & 1);     // round to nearest, ties to even
    bits &= 0xFFFF0000;                      // keep only the bfloat16 bits
    std::memcpy(&val, &bits, sizeof(val));
    return val;
}
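
// Small debugging aid (not part of the original test): prints the raw IEEE-754 bit pattern
// of a float, which makes it easy to see exactly which bits to_bfloat16_and_back() keeps
// and which it rounds away.
void print_float_bits(const char* label, float val) {
    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    std::cout << label << ": " << val << " (0x" << std::hex << std::setw(8)
              << std::setfill('0') << bits << std::dec << std::setfill(' ') << ")\n";
}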

// FP4 E2M1 format decoder
void decode_fp4_e2m1(uint8_t fp4, float& value, int& sign, int& exp, int& mantissa) {
    sign = (fp4 >> 3) & 1;
    exp = (fp4 >> 1) & 3;
    mantissa = fp4 & 1;

    // E2M1 decoding with bias = 1
    if (exp == 0) {
        // Denormal or zero
        if (mantissa == 0) {
            value = 0.0f;
        } else {
            value = (sign ? -1.0f : 1.0f) * 0.5f;  // denormal: 2^(1-bias) * (m/2) = 0.5
        }
    } else {
        // Normal number: (-1)^s * 2^(e-1) * (1 + m/2)
        float mantissa_val = 1.0f + mantissa * 0.5f;
        value = (sign ? -1.0f : 1.0f) * std::pow(2.0f, exp - 1) * mantissa_val;
    }
}
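
// For reference (not part of the original test): the eight non-negative values an E2M1 code
// can decode to, indexed by the low 3 bits of the code. Negative codes 0x8-0xF mirror these
// with the sign bit set. Derived from decode_fp4_e2m1() above.
constexpr float kE2M1Values[8] = {
    0.0f,  // 0x0: exp=0, m=0 (zero)
    0.5f,  // 0x1: exp=0, m=1 (denormal)
    1.0f,  // 0x2: exp=1, m=0
    1.5f,  // 0x3: exp=1, m=1
    2.0f,  // 0x4: exp=2, m=0
    3.0f,  // 0x5: exp=2, m=1
    4.0f,  // 0x6: exp=3, m=0
    6.0f   // 0x7: exp=3, m=1
};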

// Test kernel for single value
__global__ void test_single_value_kernel(
    float input,
    uint8_t* cuda_result,
    float* debug_info
) {
    // CUDA intrinsic conversion
    uint32_t packed_result;
    float dummy = 0.0f;  // Second value for x2 conversion

    asm volatile(
        "{\n\t"
        ".reg .b8 byte0;\n\t"
        ".reg .b32 result;\n\t"
        "cvt.rn.satfinite.e2m1x2.f32 byte0, %1, %2;\n\t"
        "mov.b32 result, {byte0, 0, 0, 0};\n\t"
        "mov.b32 %0, result;\n\t"
        "}"
        : "=r"(packed_result)
        : "f"(input), "f"(dummy)
    );

    // Extract the FP4 values
    cuda_result[0] = (packed_result >> 4) & 0xF;  // High nibble
    cuda_result[1] = packed_result & 0xF;         // Low nibble (dummy)

    // Store debug info
    debug_info[0] = input;
}
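
// Sketch (not part of the original test): the same PTX instruction converts two floats per
// output byte, so an array can be quantized pairwise. This reuses the packing convention
// from test_single_value_kernel above, where the first source operand lands in the high
// nibble of the result byte.
__global__ void convert_pairs_kernel(const float* in, uint8_t* out, int num_pairs) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_pairs) return;

    uint32_t packed;
    asm volatile(
        "{\n\t"
        ".reg .b8 byte0;\n\t"
        ".reg .b32 result;\n\t"
        "cvt.rn.satfinite.e2m1x2.f32 byte0, %1, %2;\n\t"
        "mov.b32 result, {byte0, 0, 0, 0};\n\t"
        "mov.b32 %0, result;\n\t"
        "}"
        : "=r"(packed)
        : "f"(in[2 * i]), "f"(in[2 * i + 1])
    );
    out[i] = static_cast<uint8_t>(packed & 0xFF);  // one packed e2m1x2 byte per pair of inputs
}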

// Manual FP4 conversion matching PyTorch behavior
uint8_t pytorch_style_fp4_convert(float val) {
    constexpr float F4_E2M1_MAX = 6.0f;

    // Clamp to FP4 range
    val = std::fmax(-F4_E2M1_MAX, std::fmin(F4_E2M1_MAX, val));

    if (val == 0.0f) return 0;

    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    uint32_t sign = (bits >> 31) & 1;
    int32_t exp = ((bits >> 23) & 0xFF) - 127;
    uint32_t mantissa = bits & 0x7FFFFF;

    // |val| < 0.25: closer to zero than to the smallest nonzero code (0.5)
    if (exp < -2) return sign << 3;

    // 0.25 <= |val| < 0.5: candidates are 0 and the denormal 0.5; the midpoint is 0.25
    if (exp == -2) {
        if (mantissa != 0) {
            return (sign << 3) | 0x1;  // strictly above 0.25 rounds to 0.5
        }
        return sign << 3;              // exactly 0.25 ties to the even code (0)
    }

    // 0.5 <= |val| < 1.0: candidates are the denormal 0.5 and the normal 1.0; the midpoint is 0.75
    if (exp == -1) {
        if (mantissa >= 0x400000) {
            return (sign << 3) | 0x2;  // >= 0.75 rounds to 1.0 (the tie at 0.75 goes to the even code)
        }
        return (sign << 3) | 0x1;      // rounds to 0.5
    }

    // Normal numbers (exp cannot exceed 2 after clamping, but keep the guard)
    if (exp > 2) {
        return (sign << 3) | 0x7;      // saturate to ±6.0
    }

    // Round the mantissa to 1 bit: round to nearest, ties to even
    uint32_t mantissa_bit = (mantissa >> 22) & 1;
    uint32_t round_bit = (mantissa >> 21) & 1;
    uint32_t sticky = (mantissa & 0x1FFFFF) != 0;

    if (round_bit && (mantissa_bit || sticky)) {
        mantissa_bit++;
        if (mantissa_bit > 1) {
            mantissa_bit = 0;
            exp++;
            if (exp > 2) {
                return (sign << 3) | 0x7;  // overflow saturates to ±6.0
            }
        }
    }

    return (sign << 3) | ((exp + 1) << 1) | mantissa_bit;
}
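
// Sanity-check sketch (not part of the original test): every exactly representable E2M1 value
// should survive a round trip through pytorch_style_fp4_convert() and decode_fp4_e2m1().
// Can be called at the start of main() if desired.
bool check_fp4_round_trip() {
    bool ok = true;
    for (uint8_t code = 0; code <= 15; code++) {
        if (code == 0x8) continue;  // -0 decodes to 0.0f, which re-encodes as +0
        float decoded;
        int sign, exp, mantissa;
        decode_fp4_e2m1(code, decoded, sign, exp, mantissa);
        uint8_t re_encoded = pytorch_style_fp4_convert(decoded);
        if (re_encoded != code) {
            std::cout << "Round-trip failure: 0x" << std::hex << (int)code << std::dec
                      << " -> " << decoded << " -> 0x" << std::hex << (int)re_encoded
                      << std::dec << "\n";
            ok = false;
        }
    }
    return ok;
}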

int main() {
    // Test the specific problematic bfloat16 values from the failing test
    std::vector<float> test_values = {
        1.171875f,   // From index 1011
        -1.171875f,  // From index 4941
        -0.585938f,  // From index 8192
        0.585938f,   // From index 28410
        2.5f,        // Additional test: tie between 2.0 and 3.0
        -2.5f,       // Additional test
        1.25f,       // Edge case: tie between 1.0 and 1.5
        -1.25f,      // Edge case
        0.75f,       // Tie between 0.5 and 1.0
        -0.75f       // Tie between -0.5 and -1.0
    };

    std::cout << "Testing NVFP4 quantization behavior with bfloat16 values\n";
    std::cout << "=========================================================\n\n";
    for (float orig_val : test_values) {
        // Convert to bfloat16 and back
        float bf16_val = to_bfloat16_and_back(orig_val);

        // Allocate device memory
        uint8_t* d_cuda_result;
        float* d_debug_info;
        cudaMalloc(&d_cuda_result, 2 * sizeof(uint8_t));
        cudaMalloc(&d_debug_info, sizeof(float));

        // Run kernel
        test_single_value_kernel<<<1, 1>>>(bf16_val, d_cuda_result, d_debug_info);
        cudaDeviceSynchronize();

        // Get results
        uint8_t h_cuda_result[2];
        float h_debug_info;
        cudaMemcpy(h_cuda_result, d_cuda_result, 2 * sizeof(uint8_t), cudaMemcpyDeviceToHost);
        cudaMemcpy(&h_debug_info, d_debug_info, sizeof(float), cudaMemcpyDeviceToHost);

        // Manual conversion
        uint8_t pytorch_result = pytorch_style_fp4_convert(bf16_val);

        // Decode FP4 values back to float
        float cuda_decoded, pytorch_decoded;
        int cuda_sign, cuda_exp, cuda_mantissa;
        int pt_sign, pt_exp, pt_mantissa;

        decode_fp4_e2m1(h_cuda_result[0], cuda_decoded, cuda_sign, cuda_exp, cuda_mantissa);
        decode_fp4_e2m1(pytorch_result, pytorch_decoded, pt_sign, pt_exp, pt_mantissa);

        std::cout << std::fixed << std::setprecision(6);
        std::cout << "Original value: " << orig_val << " → BF16: " << bf16_val << "\n";
        std::cout << "CUDA intrinsic result:\n";
        std::cout << "  FP4 bits: 0x" << std::hex << (int)h_cuda_result[0] << std::dec
                  << " (s=" << cuda_sign << ", e=" << cuda_exp << ", m=" << cuda_mantissa << ")\n";
        std::cout << "  Decoded: " << cuda_decoded << "\n";
        std::cout << "  Error: " << std::abs(bf16_val - cuda_decoded) << "\n";

        std::cout << "PyTorch-style result:\n";
        std::cout << "  FP4 bits: 0x" << std::hex << (int)pytorch_result << std::dec
                  << " (s=" << pt_sign << ", e=" << pt_exp << ", m=" << pt_mantissa << ")\n";
        std::cout << "  Decoded: " << pytorch_decoded << "\n";
        std::cout << "  Error: " << std::abs(bf16_val - pytorch_decoded) << "\n";

        if (h_cuda_result[0] != pytorch_result) {
            std::cout << ">>> MISMATCH! Difference: " << (int)h_cuda_result[0] - (int)pytorch_result << "\n";
            std::cout << ">>> CUDA chose: " << cuda_decoded << " (error=" << std::abs(bf16_val - cuda_decoded) << ")\n";
            std::cout << ">>> PyTorch chose: " << pytorch_decoded << " (error=" << std::abs(bf16_val - pytorch_decoded) << ")\n";
        }
        std::cout << "\n";

        // Cleanup
        cudaFree(d_cuda_result);
        cudaFree(d_debug_info);
    }

    // Test rounding behavior around critical values
    std::cout << "\nDetailed rounding analysis for 1.171875:\n";
    std::cout << "==========================================\n";
    float val = 1.171875f;
    float bf16_val = to_bfloat16_and_back(val);

    std::cout << "BF16 value: " << bf16_val << "\n";
    std::cout << "Possible FP4 E2M1 representations:\n";

    // Show all nearby FP4 values
    for (uint8_t fp4 = 0; fp4 <= 15; fp4++) {
        float decoded;
        int sign, exp, mantissa;
        decode_fp4_e2m1(fp4, decoded, sign, exp, mantissa);
        float error = std::abs(bf16_val - decoded);

        if (error < 2.0f) {  // Only show nearby values
            std::cout << "  FP4=0x" << std::hex << (int)fp4 << std::dec
                      << " → " << decoded
                      << " (error=" << error << ")";
            if (fp4 == 2) std::cout << "  <- PyTorch chooses this (1.0)";
            if (fp4 == 3) std::cout << "  <- CUDA intrinsic chooses this (1.5)";
            std::cout << "\n";
        }
    }

    // Analyze the tie-breaking
    std::cout << "\nThe value 1.171875 lies between the representable FP4 values 1.0 and 1.5\n";
    std::cout << "Distance to 1.0: " << std::abs(1.171875f - 1.0f) << "\n";
    std::cout << "Distance to 1.5: " << std::abs(1.171875f - 1.5f) << "\n";
    std::cout << "When a value lands exactly halfway between two representable codes,\n";
    std::cout << "the rounding rule decides the result: the CUDA intrinsic appears to round up,\n";
    std::cout << "while PyTorch rounds ties to even.\n";

    return 0;
}