|
// Compilation: nvcc -arch=sm_100a -std=c++17 test_nvfp4_rounding.cu -o test_nvfp4_rounding
// For Godbolt: Add flags: -arch=sm_100a -std=c++17
// Note: The cvt.rn.satfinite.e2m1x2.f32 PTX instruction used below requires
//       PTX ISA 8.6 and an sm_100a-class target or newer; sm_90 does not provide it.
| 4 | + |
| 5 | +#include <cuda_fp16.h> |
| 6 | +#include <cuda_bf16.h> |
| 7 | +#include <iostream> |
| 8 | +#include <iomanip> |
| 9 | +#include <vector> |
| 10 | +#include <cmath> |
| 11 | +#include <algorithm> |
| 12 | + |
// Decode table for 4-bit E2M1 codes: index = 4-bit code, entry = float value.
// Bit 3 selects the sign (codes 0x8..0xF mirror 0x0..0x7 with a minus sign)
// and bit 0 is the single mantissa bit.
const float fp4_e2m1_lut[16] = {
    // codes 0x0 .. 0x7: non-negative values
     0.0f,   // 0b0000
     0.5f,   // 0b0001
     1.0f,   // 0b0010
     1.5f,   // 0b0011
     2.0f,   // 0b0100
     3.0f,   // 0b0101
     4.0f,   // 0b0110
     6.0f,   // 0b0111
    // codes 0x8 .. 0xF: negated values (including negative zero)
    -0.0f,   // 0b1000
    -0.5f,   // 0b1001
    -1.0f,   // 0b1010
    -1.5f,   // 0b1011
    -2.0f,   // 0b1100
    -3.0f,   // 0b1101
    -4.0f,   // 0b1110
    -6.0f    // 0b1111
};
| 18 | + |
// Smoke-test kernel for the packed FP4 conversion instruction.
// Each thread converts four fixed floats (paired with 0.0f as the second
// operand) through inline PTX; thread 0 of block 0 then prints every packed
// byte together with its low and high nibble using device-side printf.
__global__ void test_ptx_instruction() {
    float samples[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    unsigned int packed[4];

    for (int k = 0; k < 4; ++k) {
        unsigned int widened;
        // The 8-bit conversion result is zero-extended to 32 bits so it can
        // be returned through an "r" (32-bit register) output operand.
        asm volatile (
            "{\n\t"
            " .reg .b8 fp4_byte;\n\t"
            " cvt.rn.satfinite.e2m1x2.f32 fp4_byte, %1, %2;\n\t"
            " cvt.u32.u8 %0, fp4_byte;\n\t"
            "}\n\t"
            : "=r"(widened)
            : "f"(samples[k]), "f"(0.0f)
        );
        packed[k] = widened;
    }

    // Only a single thread reports, to keep the output unduplicated.
    const bool is_reporter = (blockIdx.x == 0) && (threadIdx.x == 0);
    if (!is_reporter) {
        return;
    }

    printf("PTX Instruction Test:\n");
    for (int k = 0; k < 4; ++k) {
        const uint8_t whole = packed[k] & 0xFF;
        const uint8_t lo_nibble = whole & 0xF;
        const uint8_t hi_nibble = (whole >> 4) & 0xF;
        printf(" %.1f → byte: 0x%02x (low: 0x%x, high: 0x%x)\n",
               samples[k], whole, lo_nibble, hi_nibble);
    }
    printf("\n");
}
| 50 | + |
// Converts one input float per thread to a packed FP4 byte via inline PTX.
//
// Launch layout: 1D grid of 1D blocks; thread `blockIdx.x * blockDim.x +
// threadIdx.x` handles element with that index, threads past `count` do nothing.
//
//   inputs      - `count` floats to convert (read only)
//   fp4_outputs - receives the full packed byte for each element
//   raw_bytes   - receives the same packed byte for each element
//   count       - number of elements in all three arrays
//
// The second operand of the conversion is always 0.0f. Which operand ends up
// in which nibble of the byte is not decided here; both output arrays carry
// the whole byte so the host can work that out.
__global__ void test_fp4_conversion_kernel(
    float* inputs,
    uint8_t* fp4_outputs,
    uint8_t* raw_bytes,
    int count
) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < count) {
        const float first = inputs[tid];
        const float second = 0.0f;  // placeholder partner for the packed conversion

        // The instruction yields one byte holding two FP4 values; it is
        // zero-extended to 32 bits so it fits an "r" output operand.
        unsigned int widened = 0;
        asm volatile (
            "{\n\t"
            " .reg .b8 fp4_byte;\n\t"
            " cvt.rn.satfinite.e2m1x2.f32 fp4_byte, %1, %2;\n\t"
            " cvt.u32.u8 %0, fp4_byte;\n\t"
            "}\n\t"
            : "=r"(widened)
            : "f"(first), "f"(second)
        );

        const uint8_t packed = widened & 0xFF;
        raw_bytes[tid] = packed;
        fp4_outputs[tid] = packed;
    }
}
| 89 | + |
// Reports a failed CUDA runtime call on stderr.
// Returns true when `status` is cudaSuccess, false otherwise.
static bool cuda_ok(cudaError_t status, const char* what) {
    if (status == cudaSuccess) {
        return true;
    }
    std::cerr << "CUDA error (" << what << "): "
              << cudaGetErrorString(status) << "\n";
    return false;
}

// Host driver for the FP4 E2M1 rounding experiment.
//
// Steps:
//   1. Run the PTX smoke-test kernel.
//   2. Print the 16-entry decode table.
//   3. Convert a fixed list of tie / non-tie inputs on the device.
//   4. Decide empirically which nibble of the packed byte holds the first
//      operand (the one whose decoded values are closer to the inputs).
//   5. Print per-input results, a tie-breaking table and a detailed look at 2.5.
//
// Returns 0 on success and 1 if any CUDA call or kernel launch fails, so that
// zero-initialised host buffers are never analysed as if they were results.
int main() {
    std::cout << "Testing CUDA FP4 E2M1 Rounding Behavior\n";
    std::cout << "=======================================\n\n";

    // First, run a simple test to verify the PTX instruction works.
    test_ptx_instruction<<<1, 1>>>();
    if (!cuda_ok(cudaGetLastError(), "test_ptx_instruction launch") ||
        !cuda_ok(cudaDeviceSynchronize(), "test_ptx_instruction execution")) {
        return 1;
    }

    // Print all possible E2M1 values
    std::cout << "All possible FP4 E2M1 values:\n";
    std::cout << "-----------------------------\n";
    std::cout << "FP4 Binary Value\n";
    std::cout << "--- ------ -----\n";

    for (int i = 0; i < 16; i++) {
        std::cout << "0x" << std::hex << i << std::dec
                  << " 0b" << ((i >> 3) & 1) << ((i >> 2) & 1) << ((i >> 1) & 1) << (i & 1)
                  << " " << std::setw(5) << fp4_e2m1_lut[i] << "\n";
    }
    std::cout << "\n";

    // Test values - focusing on tie cases
    std::vector<float> test_values = {
        // Perfect tie cases
        0.75f,  // Exactly between 0.5 and 1.0
        1.25f,  // Exactly between 1.0 and 1.5
        1.75f,  // Exactly between 1.5 and 2.0
        2.5f,   // Exactly between 2.0 and 3.0
        3.5f,   // Exactly between 3.0 and 4.0
        5.0f,   // Exactly between 4.0 and 6.0
        // Negative ties
        -0.75f, -1.25f, -1.75f, -2.5f, -3.5f, -5.0f,
        // Non-ties for comparison
        0.6f, 1.1f, 2.2f, 2.8f
    };
    const size_t n = test_values.size();

    // Device buffers start as nullptr so the release helper is safe to call
    // no matter how far allocation got.
    float* d_inputs = nullptr;
    uint8_t* d_outputs = nullptr;
    uint8_t* d_raw_bytes = nullptr;
    auto release_device = [&]() {
        cuda_ok(cudaFree(d_inputs), "cudaFree(d_inputs)");
        cuda_ok(cudaFree(d_outputs), "cudaFree(d_outputs)");
        cuda_ok(cudaFree(d_raw_bytes), "cudaFree(d_raw_bytes)");
    };

    std::vector<uint8_t> h_outputs(n);
    std::vector<uint8_t> h_raw_bytes(n);

    const int threads = 256;
    const int blocks = static_cast<int>((n + threads - 1) / threads);

    // Allocate and upload; && short-circuits at the first failing call.
    bool gpu_ok =
        cuda_ok(cudaMalloc(&d_inputs, n * sizeof(float)), "cudaMalloc(d_inputs)") &&
        cuda_ok(cudaMalloc(&d_outputs, n * sizeof(uint8_t)), "cudaMalloc(d_outputs)") &&
        cuda_ok(cudaMalloc(&d_raw_bytes, n * sizeof(uint8_t)), "cudaMalloc(d_raw_bytes)") &&
        cuda_ok(cudaMemcpy(d_inputs, test_values.data(), n * sizeof(float),
                           cudaMemcpyHostToDevice),
                "copy inputs to device");

    if (gpu_ok) {
        // Run kernel, then fetch both result arrays.
        test_fp4_conversion_kernel<<<blocks, threads>>>(
            d_inputs, d_outputs, d_raw_bytes, static_cast<int>(n));
        gpu_ok =
            cuda_ok(cudaGetLastError(), "test_fp4_conversion_kernel launch") &&
            cuda_ok(cudaDeviceSynchronize(), "test_fp4_conversion_kernel execution") &&
            cuda_ok(cudaMemcpy(h_outputs.data(), d_outputs,
                               h_outputs.size() * sizeof(uint8_t),
                               cudaMemcpyDeviceToHost),
                    "copy fp4 outputs to host") &&
            cuda_ok(cudaMemcpy(h_raw_bytes.data(), d_raw_bytes,
                               h_raw_bytes.size() * sizeof(uint8_t),
                               cudaMemcpyDeviceToHost),
                    "copy raw bytes to host");
    }

    // The device buffers are not needed once the results are on the host.
    release_device();
    if (!gpu_ok) {
        return 1;
    }

    // Debug: Show raw bytes and figure out nibble ordering
    std::cout << "\nDebug: Raw bytes from PTX instruction:\n";
    std::cout << "--------------------------------------\n";

    for (size_t i = 0; i < 6 && i < n; i++) {
        uint8_t raw = h_raw_bytes[i];
        uint8_t low = raw & 0xF;
        uint8_t high = (raw >> 4) & 0xF;

        // setfill is sticky: restore the blank fill right after the hex byte
        // so the setw(5) decoded values below are not zero-padded.
        std::cout << "Input: " << std::setw(6) << test_values[i]
                  << " → Raw byte: 0x" << std::hex << std::setw(2) << std::setfill('0')
                  << (int)raw << std::dec << std::setfill(' ');

        float decoded_low = fp4_e2m1_lut[low];
        float decoded_high = fp4_e2m1_lut[high];
        float error_low = std::abs(test_values[i] - decoded_low);
        float error_high = std::abs(test_values[i] - decoded_high);

        std::cout << "\n Low nibble: 0x" << std::hex << (int)low << std::dec
                  << " = " << std::setw(5) << decoded_low
                  << " (error: " << error_low << ")";
        if (error_low < error_high) std::cout << " ← likely match";

        std::cout << "\n High nibble: 0x" << std::hex << (int)high << std::dec
                  << " = " << std::setw(5) << decoded_high
                  << " (error: " << error_high << ")";
        if (error_high < error_low) std::cout << " ← likely match";

        std::cout << "\n\n";
    }

    // Determine which nibble to use based on accumulated absolute error
    std::cout << "Determining nibble assignment...\n";
    float total_error_low = 0, total_error_high = 0;

    for (size_t i = 0; i < n; i++) {
        uint8_t raw = h_raw_bytes[i];
        uint8_t low = raw & 0xF;
        uint8_t high = (raw >> 4) & 0xF;

        total_error_low += std::abs(test_values[i] - fp4_e2m1_lut[low]);
        total_error_high += std::abs(test_values[i] - fp4_e2m1_lut[high]);
    }

    bool use_low_nibble = (total_error_low <= total_error_high);
    std::cout << "Total error using low nibble: " << total_error_low << "\n";
    std::cout << "Total error using high nibble: " << total_error_high << "\n";
    std::cout << "Decision: Use " << (use_low_nibble ? "LOW" : "HIGH") << " nibble for first float\n\n";

    // Update outputs based on decision
    for (size_t i = 0; i < h_outputs.size(); i++) {
        if (use_low_nibble) {
            h_outputs[i] = h_raw_bytes[i] & 0xF;
        } else {
            h_outputs[i] = (h_raw_bytes[i] >> 4) & 0xF;
        }
    }

    // Analyze results
    std::cout << "Conversion Results:\n";
    std::cout << "==================\n";
    std::cout << std::setw(8) << "Input"
              << std::setw(8) << "FP4"
              << std::setw(10) << "Decoded"
              << std::setw(10) << "Error"
              << std::setw(20) << "Comment\n";

    for (size_t i = 0; i < n; i++) {
        float input = test_values[i];
        uint8_t fp4 = h_outputs[i];
        float decoded = fp4_e2m1_lut[fp4];
        float error = std::abs(input - decoded);

        std::cout << std::fixed << std::setprecision(3);
        std::cout << std::setw(8) << input
                  << std::setw(8) << "0x" << std::hex << (int)fp4 << std::dec
                  << std::setw(10) << decoded
                  << std::setw(10) << error;

        // Tie check: find the nearest representable value, then look for a
        // second, numerically different value at the same distance. Comparing
        // by value treats +0 and -0 as one candidate, and using the true
        // nearest distance (not a fixed cutoff) also catches the 4-vs-6 tie.
        float nearest = fp4_e2m1_lut[0];
        float nearest_dist = std::abs(input - nearest);
        for (int j = 1; j < 16; j++) {
            float dist = std::abs(input - fp4_e2m1_lut[j]);
            if (dist < nearest_dist) {
                nearest = fp4_e2m1_lut[j];
                nearest_dist = dist;
            }
        }

        bool is_tie = false;
        float rival = 0;
        for (int j = 0; j < 16 && !is_tie; j++) {
            float candidate = fp4_e2m1_lut[j];
            if (candidate != nearest &&
                std::abs(std::abs(input - candidate) - nearest_dist) < 1e-6f) {
                is_tie = true;
                rival = candidate;
            }
        }

        if (is_tie) {
            float option1 = std::min(nearest, rival);
            float option2 = std::max(nearest, rival);
            std::cout << " (tie: " << option1 << " vs " << option2 << ")";
        }
        std::cout << "\n";
    }

    // Detailed tie analysis
    std::cout << "\n\nTie-Breaking Analysis:\n";
    std::cout << "=====================\n";
    std::cout << "For round-to-nearest-even, ties should go to the value with even mantissa (m=0)\n\n";

    std::cout << std::setw(8) << "Input"
              << std::setw(15) << "Options"
              << std::setw(10) << "Result"
              << std::setw(10) << "Mantissa"
              << std::setw(20) << "Round-to-even?\n";

    // Analyze specific tie cases; tie_options[i] lists the two neighbours of tie_values[i].
    std::vector<float> tie_values = {0.75f, 1.25f, 1.75f, 2.5f, 3.5f, 5.0f};
    std::vector<std::pair<float, float>> tie_options = {
        {0.5f, 1.0f}, {1.0f, 1.5f}, {1.5f, 2.0f},
        {2.0f, 3.0f}, {3.0f, 4.0f}, {4.0f, 6.0f}
    };

    for (size_t i = 0; i < tie_values.size(); i++) {
        // Find the result in our test
        for (size_t j = 0; j < n; j++) {
            if (std::abs(test_values[j] - tie_values[i]) < 1e-6f) {
                uint8_t fp4 = h_outputs[j];
                float decoded = fp4_e2m1_lut[fp4];
                int mantissa = fp4 & 1;  // bit 0 of the code is the mantissa bit

                std::cout << std::setw(8) << tie_values[i]
                          << std::setw(15) << tie_options[i].first
                          << " vs " << tie_options[i].second
                          << std::setw(10) << decoded
                          << std::setw(10) << "m=" << mantissa
                          << std::setw(20) << (mantissa == 0 ? "Yes ✓" : "No ✗")
                          << "\n";
                break;
            }
        }
    }

    // Special focus on 2.5
    std::cout << "\n\n2.5 Detailed Analysis:\n";
    std::cout << "=====================\n";
    for (size_t i = 0; i < n; i++) {
        if (std::abs(test_values[i] - 2.5f) < 1e-6f) {
            uint8_t fp4 = h_outputs[i];
            float decoded = fp4_e2m1_lut[fp4];

            std::cout << "Input: 2.5\n";
            std::cout << "FP4 encoding: 0x" << std::hex << (int)fp4 << std::dec
                      << " (0b" << ((fp4 >> 3) & 1) << ((fp4 >> 2) & 1)
                      << ((fp4 >> 1) & 1) << (fp4 & 1) << ")\n";
            std::cout << "Decoded value: " << decoded << "\n";
            std::cout << "Mantissa bit: " << (fp4 & 1) << "\n\n";

            std::cout << "Options were:\n";
            std::cout << " 2.0 (0x4, m=0) - even mantissa\n";
            std::cout << " 3.0 (0x5, m=1) - odd mantissa\n\n";

            if (fp4 == 0x4) {
                std::cout << "✓ Correctly rounded to even (2.0)\n";
            } else if (fp4 == 0x5) {
                std::cout << "✗ Rounded to odd (3.0) - NOT round-to-nearest-even\n";
            } else {
                std::cout << "! Unexpected result\n";
            }
            break;
        }
    }

    return 0;
}
0 commit comments