
Commit b94abfa

script
stack-info: PR: #8, branch: drisspg/stack/2
1 parent 4f72ffe commit b94abfa

File tree

1 file changed: +231 −0 lines changed


examples/test_nvfp4_rounding.cu

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
// Compilation: nvcc -arch=sm_100a -std=c++17 test_nvfp4_rounding.cu -o test_nvfp4_rounding
// For Godbolt: add flags: -arch=sm_100a -std=c++17
// Note: the e2m1 (FP4) cvt instructions require a Blackwell-class target
// (sm_100a or newer, PTX ISA 8.6); they are not available on sm_90.

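// E2M1 refresher (added summary): 1 sign bit, 2 exponent bits (bias 1),
// 1 mantissa bit. Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
// The format has no Inf/NaN encodings, which is why the conversion below
// uses .satfinite: out-of-range inputs saturate to the max finite ±6.0.
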
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>   // uint8_t, uint32_t (was missing)
#include <cstdlib>   // std::exit, used by the CUDA_CHECK helper below
#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <cstring>   // std::memcpy for type punning

// Convert float to bfloat16 and back, emulating a bf16 round trip on the host.
// Hardware bf16 conversion rounds to nearest-even, so do the same here;
// std::memcpy avoids the strict-aliasing UB of casting float* to uint32_t*.
float to_bfloat16_and_back(float val) {
    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    bits += 0x7FFF + ((bits >> 16) & 1); // Round to nearest, ties to even
    bits &= 0xFFFF0000;                  // Keep the top 16 bits (bf16)
    std::memcpy(&val, &bits, sizeof(val));
    return val;
}

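// Note (added): bf16 keeps 8 significand bits (7 explicit fraction bits), and
// 1.171875 = 1 + 11/64 needs only 6, so the headline value below survives the
// round trip bit-exact.
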
// FP4 E2M1 format decoder
void decode_fp4_e2m1(uint8_t fp4, float& value, int& sign, int& exp, int& mantissa) {
    sign = (fp4 >> 3) & 1;
    exp = (fp4 >> 1) & 3;
    mantissa = fp4 & 1;

    // E2M1 decoding with bias=1
    if (exp == 0) {
        // Denormal or zero
        if (mantissa == 0) {
            value = 0.0f;
        } else {
            value = (sign ? -1.0f : 1.0f) * 0.5f; // denormal: 2^0 * (1/2)
        }
    } else {
        // Normal number: (-1)^s * 2^(e-1) * (1 + m/2)
        float mantissa_val = 1.0f + mantissa * 0.5f;
        value = (sign ? -1.0f : 1.0f) * std::pow(2.0f, exp - 1) * mantissa_val;
    }
}

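// Sanity check (added helper, not part of the original experiment): verify the
// decoder against the well-known E2M1 value table; codes 8..15 mirror codes
// 0..7 with the sign flipped.
bool verify_e2m1_table() {
    const float expected[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    for (uint8_t code = 0; code < 8; code++) {
        float decoded; int s, e, m;
        decode_fp4_e2m1(code, decoded, s, e, m);
        if (decoded != expected[code]) return false;
        decode_fp4_e2m1(code | 0x8, decoded, s, e, m);
        if (decoded != -expected[code]) return false;
    }
    return true;
}
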
// Test kernel: convert a single float to FP4 E2M1 via the PTX cvt instruction
__global__ void test_single_value_kernel(
    float input,
    uint8_t* cuda_result,
    float* debug_info
) {
    // CUDA intrinsic conversion
    uint32_t packed_result;
    float dummy = 0.0f; // Second operand for the x2 (paired) conversion

    // cvt.rn.satfinite.e2m1x2.f32 packs two floats into one byte:
    // %1 (input) goes to the high nibble, %2 (dummy) to the low nibble.
    asm volatile (
        "{\n\t"
        ".reg .b8 byte0;\n\t"
        ".reg .b32 result;\n\t"
        "cvt.rn.satfinite.e2m1x2.f32 byte0, %1, %2;\n\t"
        "mov.b32 result, {byte0, 0, 0, 0};\n\t"
        "mov.b32 %0, result;\n\t"
        "}"
        : "=r"(packed_result)
        : "f"(input), "f"(dummy)
    );

    // Extract the two FP4 nibbles
    cuda_result[0] = (packed_result >> 4) & 0xF; // High nibble (input)
    cuda_result[1] = packed_result & 0xF;        // Low nibble (dummy)

    // Store debug info
    debug_info[0] = input;
}

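// Minimal CUDA error check (added convenience; the original script ignores
// return codes). Wrapping the cudaMalloc/cudaMemcpy calls in main() with this
// would surface failures, e.g. when running on a target without FP4 support.
#define CUDA_CHECK(expr)                                                    \
    do {                                                                    \
        cudaError_t err_ = (expr);                                          \
        if (err_ != cudaSuccess) {                                          \
            std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__    \
                      << ": " << cudaGetErrorString(err_) << "\n";          \
            std::exit(1);                                                   \
        }                                                                   \
    } while (0)
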
// Manual FP4 conversion intended to match PyTorch's round-to-nearest-even
uint8_t pytorch_style_fp4_convert(float val) {
    constexpr float F4_E2M1_MAX = 6.0f;

    // Clamp to FP4 range
    val = std::fmax(-F4_E2M1_MAX, std::fmin(F4_E2M1_MAX, val));

    if (val == 0.0f) return 0;

    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    uint32_t sign = (bits >> 31) & 1;
    int32_t exp = ((bits >> 23) & 0xFF) - 127;
    uint32_t mantissa = bits & 0x7FFFFF;

    // |val| < 0.25: below the midpoint between 0 and the denormal 0.5
    if (exp < -2) return sign << 3; // Underflow to zero

    // 0.25 <= |val| < 0.5: candidates are 0 and the denormal 0.5.
    // Exactly 0.25 (mantissa == 0) is a tie and rounds to 0 (even code);
    // anything above the midpoint rounds to 0.5.
    if (exp == -2) {
        return (sign << 3) | (mantissa != 0 ? 0x1 : 0x0);
    }

    // 0.5 <= |val| < 1.0: candidates are 0.5 (denormal) and 1.0.
    // The midpoint 0.75 ties to 1.0, whose mantissa bit is even.
    if (exp == -1) {
        if (mantissa >= 0x400000) return (sign << 3) | 0x2; // 1.0
        return (sign << 3) | 0x1;                           // 0.5
    }

    // Normal numbers
    if (exp > 2) {
        // Overflow to max (±6.0)
        return (sign << 3) | 0x7;
    }

    // Round the 23-bit mantissa down to 1 bit
    uint32_t mantissa_bit = (mantissa >> 22) & 1;
    uint32_t round_bit = (mantissa >> 21) & 1;

    // Round to nearest, ties to even: round up when the round bit is set
    // and either the kept bit is odd or a sticky bit below is set
    if (round_bit && ((mantissa_bit == 1) || ((mantissa & 0x1FFFFF) != 0))) {
        mantissa_bit++;
        if (mantissa_bit > 1) {
            mantissa_bit = 0;
            exp++;
            if (exp > 2) {
                return (sign << 3) | 0x7; // Overflow
            }
        }
    }

    return (sign << 3) | ((exp + 1) << 1) | mantissa_bit;
}

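// Cross-check helper (added illustration, not from the original commit):
// brute-force round-to-nearest-even over all 16 E2M1 code points. Useful to
// validate the bit-twiddling in pytorch_style_fp4_convert above.
uint8_t brute_force_fp4_convert(float val) {
    if (val == 0.0f) return 0;
    constexpr float F4_E2M1_MAX = 6.0f;
    val = std::fmax(-F4_E2M1_MAX, std::fmin(F4_E2M1_MAX, val));
    uint8_t best = 0;
    float best_err = 1e30f;
    for (uint8_t code = 0; code <= 15; code++) {
        float decoded; int s, e, m;
        decode_fp4_e2m1(code, decoded, s, e, m);
        float err = std::abs(val - decoded);
        // Strictly closer wins; an exact tie prefers the even mantissa bit
        if (err < best_err || (err == best_err && (code & 1) == 0)) {
            best_err = err;
            best = code;
        }
    }
    return best;
}
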
int main() {
    // Test the specific problematic bfloat16 values from the failing test
    std::vector<float> test_values = {
        1.171875f,   // From index 1011
        -1.171875f,  // From index 4941
        -0.585938f,  // From index 8192
        0.585938f,   // From index 28410
        2.5f,        // Additional test (exact tie between 2.0 and 3.0)
        -2.5f,       // Additional test
        1.25f,       // Edge case (exact tie between 1.0 and 1.5)
        -1.25f,      // Edge case
        0.75f,       // Exact tie between 0.5 and 1.0
        -0.75f       // Exact tie between -0.5 and -1.0
    };

    std::cout << "Testing NVFP4 quantization behavior with bfloat16 values\n";
    std::cout << "=========================================================\n\n";

    for (float orig_val : test_values) {
        // Convert to bfloat16 and back
        float bf16_val = to_bfloat16_and_back(orig_val);

        // Allocate device memory
        uint8_t* d_cuda_result;
        float* d_debug_info;
        cudaMalloc(&d_cuda_result, 2 * sizeof(uint8_t));
        cudaMalloc(&d_debug_info, sizeof(float));

        // Run kernel
        test_single_value_kernel<<<1, 1>>>(bf16_val, d_cuda_result, d_debug_info);
        cudaDeviceSynchronize();

        // Get results
        uint8_t h_cuda_result[2];
        float h_debug_info;
        cudaMemcpy(h_cuda_result, d_cuda_result, 2 * sizeof(uint8_t), cudaMemcpyDeviceToHost);
        cudaMemcpy(&h_debug_info, d_debug_info, sizeof(float), cudaMemcpyDeviceToHost);

        // Manual conversion
        uint8_t pytorch_result = pytorch_style_fp4_convert(bf16_val);

        // Decode FP4 values back to float
        float cuda_decoded, pytorch_decoded;
        int cuda_sign, cuda_exp, cuda_mantissa;
        int pt_sign, pt_exp, pt_mantissa;

        decode_fp4_e2m1(h_cuda_result[0], cuda_decoded, cuda_sign, cuda_exp, cuda_mantissa);
        decode_fp4_e2m1(pytorch_result, pytorch_decoded, pt_sign, pt_exp, pt_mantissa);

        std::cout << std::fixed << std::setprecision(6);
        std::cout << "Original value: " << orig_val << " → BF16: " << bf16_val << "\n";
        std::cout << "  CUDA intrinsic result:\n";
        std::cout << "    FP4 bits: 0x" << std::hex << (int)h_cuda_result[0] << std::dec
                  << " (s=" << cuda_sign << ", e=" << cuda_exp << ", m=" << cuda_mantissa << ")\n";
        std::cout << "    Decoded: " << cuda_decoded << "\n";
        std::cout << "    Error: " << std::abs(bf16_val - cuda_decoded) << "\n";

        std::cout << "  PyTorch-style result:\n";
        std::cout << "    FP4 bits: 0x" << std::hex << (int)pytorch_result << std::dec
                  << " (s=" << pt_sign << ", e=" << pt_exp << ", m=" << pt_mantissa << ")\n";
        std::cout << "    Decoded: " << pytorch_decoded << "\n";
        std::cout << "    Error: " << std::abs(bf16_val - pytorch_decoded) << "\n";

        if (h_cuda_result[0] != pytorch_result) {
            std::cout << "  >>> MISMATCH! Difference: " << (int)h_cuda_result[0] - (int)pytorch_result << "\n";
            std::cout << "  >>> CUDA chose: " << cuda_decoded << " (error=" << std::abs(bf16_val - cuda_decoded) << ")\n";
            std::cout << "  >>> PyTorch chose: " << pytorch_decoded << " (error=" << std::abs(bf16_val - pytorch_decoded) << ")\n";
        }
        std::cout << "\n";

        // Cleanup
        cudaFree(d_cuda_result);
        cudaFree(d_debug_info);
    }

    // Test rounding behavior around a critical value
    std::cout << "\nDetailed rounding analysis for 1.171875:\n";
    std::cout << "==========================================\n";
    float val = 1.171875f;
    float bf16_val = to_bfloat16_and_back(val);

    std::cout << "BF16 value: " << bf16_val << "\n";
    std::cout << "Possible FP4 E2M1 representations:\n";

    // Show all nearby FP4 values
    for (uint8_t fp4 = 0; fp4 <= 15; fp4++) {
        float decoded;
        int sign, exp, mantissa;
        decode_fp4_e2m1(fp4, decoded, sign, exp, mantissa);
        float error = std::abs(bf16_val - decoded);

        if (error < 2.0f) { // Only show nearby values
            std::cout << "  FP4=0x" << std::hex << (int)fp4 << std::dec
                      << " → " << decoded
                      << " (error=" << error << ")";
            if (fp4 == 2) std::cout << " <- PyTorch chooses this (1.0)";
            if (fp4 == 3) std::cout << " <- CUDA intrinsic chooses this (1.5)";
            std::cout << "\n";
        }
    }

    // Analyze the rounding
    std::cout << "\nThe value 1.171875 lies between 1.0 and 1.5 (midpoint 1.25)\n";
    std::cout << "Distance to 1.0: " << std::abs(1.171875f - 1.0f) << "\n";
    std::cout << "Distance to 1.5: " << std::abs(1.171875f - 1.5f) << "\n";
    std::cout << "1.0 is strictly closer, so round-to-nearest should pick 1.0.\n";
    std::cout << "A result of 1.5 here indicates round-up rather than round-to-nearest.\n";

    return 0;
}
