Skip to content

Commit f3ca380

Browse files
committed
script
stack-info: PR: #8, branch: drisspg/stack/2
1 parent 4f72ffe commit f3ca380

File tree

3 files changed

+338
-2
lines changed

3 files changed

+338
-2
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ foreach(EXAMPLE_SOURCE ${EXAMPLE_SOURCES})
5353

5454
# CUDA properties provided by CMAKE
5555
set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
56-
set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_ARCHITECTURES 100)
56+
set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_ARCHITECTURES 100a)
5757

5858
# Convert the flags string into a list of flags
5959
separate_arguments(EXTRA_CUDA_FLAGS_LIST UNIX_COMMAND "${EXTRA_CUDA_FLAGS}")

examples/misc/atomic_max.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include <iostream>
33
#include <limits>
44

5-
__device__ __forceinline__ float atomicMaxFloatBAD(float *addr, float value) {
5+
__device__ __forceinline__ void atomicMaxFloatBAD(float *addr, float value) {
66
// source: https://stackoverflow.com/a/51549250
77
(value >= 0)
88
? __int_as_float(atomicMax((int *)addr, __float_as_int(value)))

examples/test_nvfp4_rounding.cu

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
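// Compilation: nvcc -arch=sm_100a -std=c++17 test_nvfp4_rounding.cu -o test_nvfp4_rounding
2+
// For Godbolt: Add flags: -arch=sm_100a -std=c++17
3+
// Note: The cvt.rn.satfinite.e2m1x2.f32 PTX instruction used below needs an architecture-specific Blackwell target (sm_100a); plain sm_90 will not assemble it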
4+
5+
#include <cuda_fp16.h>
6+
#include <cuda_bf16.h>
7+
#include <iostream>
8+
#include <iomanip>
9+
#include <vector>
10+
#include <cmath>
11+
#include <algorithm>
12+
13+
// FP4 E2M1 lookup table: entry i is the value decoded from the 4-bit code i.
// Codes 0x0-0x7 hold the non-negative values and codes 0x8-0xF hold the same
// magnitudes negated, so bit 3 acts as the sign bit. The least significant
// bit is the single mantissa bit (the host code below reads it as `fp4 & 1`).
14+
const float fp4_e2m1_lut[16] = {
15+
0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
16+
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
17+
};
18+
19+
// Smoke test for the packed FP4 (E2M1) conversion instruction.
//
// Every thread converts four fixed floats, each paired with 0.0f, through
// cvt.rn.satfinite.e2m1x2.f32 and widens the resulting byte to 32 bits.
// Only thread 0 of block 0 prints the packed byte and its two nibbles.
// Takes no arguments and writes nothing to global memory; intended to be
// launched as <<<1, 1>>>.
__global__ void test_ptx_instruction() {
    constexpr int kNumSamples = 4;
    const float samples[kNumSamples] = {1.0f, 2.0f, 3.0f, 4.0f};
    unsigned int packed[kNumSamples];

    for (int s = 0; s < kNumSamples; ++s) {
        // The instruction writes an 8-bit register, so it is staged in a
        // scoped .b8 register and then zero-extended into the 32-bit output.
        asm volatile (
            "{\n\t"
            " .reg .b8 fp4_byte;\n\t"
            " cvt.rn.satfinite.e2m1x2.f32 fp4_byte, %1, %2;\n\t"
            " cvt.u32.u8 %0, fp4_byte;\n\t"
            "}\n\t"
            : "=r"(packed[s])
            : "f"(samples[s]), "f"(0.0f)
        );
    }

    // Everything below is reporting; a single thread is enough.
    if (threadIdx.x != 0 || blockIdx.x != 0) {
        return;
    }

    printf("PTX Instruction Test:\n");
    for (int s = 0; s < kNumSamples; ++s) {
        const uint8_t byte = packed[s] & 0xFF;
        const uint8_t low = byte & 0xF;
        const uint8_t high = (byte >> 4) & 0xF;
        printf(" %.1f → byte: 0x%02x (low: 0x%x, high: 0x%x)\n",
               samples[s], byte, low, high);
    }
    printf("\n");
}
50+
51+
// Converts each input float to FP4 (E2M1) with the packed PTX conversion.
//
// One thread handles one element; threads whose global index is >= count do
// nothing. Each input is paired with a dummy 0.0f, and the packed byte that
// comes back (two 4-bit codes) is stored unchanged in BOTH output arrays.
// The host decides which nibble belongs to the input.
//
// NOTE(review): the PTX ISA description of cvt for .e2m1x2 places the first
// source operand in the upper nibble and the second in the lower nibble.
// Confirm against the PTX documentation before hard-coding it.
//
// inputs      - device array of `count` floats to convert
// fp4_outputs - device array of `count` bytes; receives the packed byte
// raw_bytes   - device array of `count` bytes; receives the same packed byte
// count       - number of elements
__global__ void test_fp4_conversion_kernel(
    float* inputs,
    uint8_t* fp4_outputs,
    uint8_t* raw_bytes,
    int count
) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < count) {
        const float primary = inputs[tid];
        const float filler = 0.0f;  // second lane of the packed conversion is unused

        // The conversion targets an 8-bit register; widen it to 32 bits so it
        // can leave the asm block through an "r" constraint.
        unsigned int packed = 0;
        asm volatile (
            "{\n\t"
            " .reg .b8 fp4_byte;\n\t"
            " cvt.rn.satfinite.e2m1x2.f32 fp4_byte, %1, %2;\n\t"
            " cvt.u32.u8 %0, fp4_byte;\n\t"
            "}\n\t"
            : "=r"(packed)
            : "f"(primary), "f"(filler)
        );

        // Keep only the meaningful low byte and publish it to both buffers.
        const uint8_t packed_byte = packed & 0xFF;
        raw_bytes[tid] = packed_byte;
        fp4_outputs[tid] = packed_byte;
    }
}
89+
90+
// Host driver: runs the PTX smoke test, converts a set of tie and non-tie
// floats to FP4 E2M1 on the device, works out empirically which nibble of the
// packed byte holds the converted input, and reports how ties were rounded.
//
// Returns 0 on success and 1 if any CUDA runtime call or kernel launch fails
// (for example when the device cannot execute the FP4 conversion), so that
// uninitialised device memory is never analysed as if it were a result.
int main() {
    // Prints a diagnostic for a failed CUDA call; true means `status` was cudaSuccess.
    auto cuda_ok = [](cudaError_t status, const char* what) {
        if (status == cudaSuccess) return true;
        std::cerr << "CUDA error (" << what << "): "
                  << cudaGetErrorString(status) << "\n";
        return false;
    };

    std::cout << "Testing CUDA FP4 E2M1 Rounding Behavior\n";
    std::cout << "=======================================\n\n";

    // First, run a simple test to verify PTX instruction works
    test_ptx_instruction<<<1, 1>>>();
    if (!cuda_ok(cudaGetLastError(), "test_ptx_instruction launch") ||
        !cuda_ok(cudaDeviceSynchronize(), "test_ptx_instruction execution")) {
        return 1;
    }

    // Print all possible E2M1 values
    std::cout << "All possible FP4 E2M1 values:\n";
    std::cout << "-----------------------------\n";
    std::cout << "FP4 Binary Value\n";
    std::cout << "--- ------ -----\n";

    for (int code = 0; code <= 15; code++) {
        std::cout << "0x" << std::hex << code << std::dec
                  << " 0b" << ((code >> 3) & 1) << ((code >> 2) & 1)
                  << ((code >> 1) & 1) << (code & 1)
                  << " " << std::setw(5) << fp4_e2m1_lut[code] << "\n";
    }
    std::cout << "\n";

    // Test values - focusing on tie cases
    std::vector<float> test_values = {
        // Perfect tie cases
        0.75f,  // Exactly between 0.5 and 1.0
        1.25f,  // Exactly between 1.0 and 1.5
        1.75f,  // Exactly between 1.5 and 2.0
        2.5f,   // Exactly between 2.0 and 3.0
        3.5f,   // Exactly between 3.0 and 4.0
        5.0f,   // Exactly between 4.0 and 6.0
        // Negative ties
        -0.75f, -1.25f, -1.75f, -2.5f, -3.5f, -5.0f,
        // Non-ties for comparison
        0.6f, 1.1f, 2.2f, 2.8f
    };
    const size_t num_values = test_values.size();

    // Device buffers start as nullptr so the cleanup helper is safe to call
    // from any failure point (cudaFree(nullptr) is a no-op).
    float* d_inputs = nullptr;
    uint8_t* d_outputs = nullptr;
    uint8_t* d_raw_bytes = nullptr;
    auto release_device_buffers = [&]() {
        cudaFree(d_inputs);
        cudaFree(d_outputs);
        cudaFree(d_raw_bytes);
    };

    // Allocate memory and copy inputs
    if (!cuda_ok(cudaMalloc(&d_inputs, num_values * sizeof(float)), "cudaMalloc d_inputs") ||
        !cuda_ok(cudaMalloc(&d_outputs, num_values * sizeof(uint8_t)), "cudaMalloc d_outputs") ||
        !cuda_ok(cudaMalloc(&d_raw_bytes, num_values * sizeof(uint8_t)), "cudaMalloc d_raw_bytes") ||
        !cuda_ok(cudaMemcpy(d_inputs, test_values.data(),
                            num_values * sizeof(float),
                            cudaMemcpyHostToDevice), "copy inputs to device")) {
        release_device_buffers();
        return 1;
    }

    // Run kernel
    const int threads = 256;
    const int blocks = static_cast<int>((num_values + threads - 1) / threads);
    test_fp4_conversion_kernel<<<blocks, threads>>>(
        d_inputs, d_outputs, d_raw_bytes, static_cast<int>(num_values));
    if (!cuda_ok(cudaGetLastError(), "test_fp4_conversion_kernel launch") ||
        !cuda_ok(cudaDeviceSynchronize(), "test_fp4_conversion_kernel execution")) {
        release_device_buffers();
        return 1;
    }

    // Get results
    std::vector<uint8_t> h_outputs(num_values);
    std::vector<uint8_t> h_raw_bytes(num_values);
    if (!cuda_ok(cudaMemcpy(h_outputs.data(), d_outputs,
                            h_outputs.size() * sizeof(uint8_t),
                            cudaMemcpyDeviceToHost), "copy outputs to host") ||
        !cuda_ok(cudaMemcpy(h_raw_bytes.data(), d_raw_bytes,
                            h_raw_bytes.size() * sizeof(uint8_t),
                            cudaMemcpyDeviceToHost), "copy raw bytes to host")) {
        release_device_buffers();
        return 1;
    }

    // Debug: Show raw bytes and figure out nibble ordering
    std::cout << "\nDebug: Raw bytes from PTX instruction:\n";
    std::cout << "--------------------------------------\n";

    for (size_t i = 0; i < 6 && i < test_values.size(); i++) {
        uint8_t raw = h_raw_bytes[i];
        uint8_t low = raw & 0xF;
        uint8_t high = (raw >> 4) & 0xF;

        std::cout << "Input: " << std::setw(6) << test_values[i]
                  << " → Raw byte: 0x" << std::hex << std::setw(2) << std::setfill('0') << (int)raw << std::dec;

        float decoded_low = fp4_e2m1_lut[low];
        float decoded_high = fp4_e2m1_lut[high];
        float error_low = std::abs(test_values[i] - decoded_low);
        float error_high = std::abs(test_values[i] - decoded_high);

        std::cout << "\n Low nibble: 0x" << std::hex << (int)low << std::dec
                  << " = " << std::setw(5) << decoded_low
                  << " (error: " << error_low << ")";
        if (error_low < error_high) std::cout << " ← likely match";

        std::cout << "\n High nibble: 0x" << std::hex << (int)high << std::dec
                  << " = " << std::setw(5) << decoded_high
                  << " (error: " << error_high << ")";
        if (error_high < error_low) std::cout << " ← likely match";

        // Restore the fill character changed for the hex byte above.
        std::cout << std::setfill(' ') << "\n\n";
    }

    // Determine which nibble to use based on errors: the nibble holding the
    // converted input decodes close to it, the other one decodes the dummy 0.0f.
    std::cout << "Determining nibble assignment...\n";
    float total_error_low = 0, total_error_high = 0;

    for (size_t i = 0; i < test_values.size(); i++) {
        uint8_t raw = h_raw_bytes[i];
        uint8_t low = raw & 0xF;
        uint8_t high = (raw >> 4) & 0xF;

        total_error_low += std::abs(test_values[i] - fp4_e2m1_lut[low]);
        total_error_high += std::abs(test_values[i] - fp4_e2m1_lut[high]);
    }

    bool use_low_nibble = (total_error_low <= total_error_high);
    std::cout << "Total error using low nibble: " << total_error_low << "\n";
    std::cout << "Total error using high nibble: " << total_error_high << "\n";
    std::cout << "Decision: Use " << (use_low_nibble ? "LOW" : "HIGH") << " nibble for first float\n\n";

    // Update outputs based on decision
    for (size_t i = 0; i < h_outputs.size(); i++) {
        if (use_low_nibble) {
            h_outputs[i] = h_raw_bytes[i] & 0xF;
        } else {
            h_outputs[i] = (h_raw_bytes[i] >> 4) & 0xF;
        }
    }

    // Analyze results
    std::cout << "Conversion Results:\n";
    std::cout << "==================\n";
    std::cout << std::setw(8) << "Input"
              << std::setw(8) << "FP4"
              << std::setw(10) << "Decoded"
              << std::setw(10) << "Error"
              << std::setw(20) << "Comment\n";

    for (size_t i = 0; i < test_values.size(); i++) {
        float input = test_values[i];
        uint8_t fp4 = h_outputs[i];
        float decoded = fp4_e2m1_lut[fp4];
        float error = std::abs(input - decoded);

        std::cout << std::fixed << std::setprecision(3);
        std::cout << std::setw(8) << input
                  << std::setw(8) << "0x" << std::hex << (int)fp4 << std::dec
                  << std::setw(10) << decoded
                  << std::setw(10) << error;

        // Check for tie cases: the input is a tie when it sits exactly on the
        // midpoint of two neighbouring representable values of the same sign.
        // Each half of the table (codes 0-7 and 8-15) is ordered by magnitude,
        // so neighbours are consecutive entries. Comparing against midpoints
        // rather than a fixed distance cut-off also catches 5.0 / -5.0, whose
        // neighbours 4.0 and 6.0 are a full 1.0 away.
        bool is_tie = false;
        float option1 = 0, option2 = 0;
        for (int base = 0; base < 16 && !is_tie; base += 8) {
            for (int j = base; j < base + 7; j++) {
                const float a = fp4_e2m1_lut[j];
                const float b = fp4_e2m1_lut[j + 1];
                if (std::abs(input - 0.5f * (a + b)) < 1e-6f) {
                    is_tie = true;
                    option1 = std::min(a, b);
                    option2 = std::max(a, b);
                    break;
                }
            }
        }

        if (is_tie) {
            std::cout << " (tie: " << option1 << " vs " << option2 << ")";
        }
        std::cout << "\n";
    }

    // Detailed tie analysis
    std::cout << "\n\nTie-Breaking Analysis:\n";
    std::cout << "=====================\n";
    std::cout << "For round-to-nearest-even, ties should go to the value with even mantissa (m=0)\n\n";

    std::cout << std::setw(8) << "Input"
              << std::setw(15) << "Options"
              << std::setw(10) << "Result"
              << std::setw(10) << "Mantissa"
              << std::setw(20) << "Round-to-even?\n";

    // Analyze specific tie cases
    std::vector<float> tie_values = {0.75f, 1.25f, 1.75f, 2.5f, 3.5f, 5.0f};
    std::vector<std::pair<float, float>> tie_options = {
        {0.5f, 1.0f}, {1.0f, 1.5f}, {1.5f, 2.0f},
        {2.0f, 3.0f}, {3.0f, 4.0f}, {4.0f, 6.0f}
    };

    for (size_t i = 0; i < tie_values.size(); i++) {
        // Find the result in our test
        for (size_t j = 0; j < test_values.size(); j++) {
            if (std::abs(test_values[j] - tie_values[i]) < 1e-6f) {
                uint8_t fp4 = h_outputs[j];
                float decoded = fp4_e2m1_lut[fp4];
                int mantissa = fp4 & 1;  // lowest bit of the code is the mantissa bit

                std::cout << std::setw(8) << tie_values[i]
                          << std::setw(15) << tie_options[i].first
                          << " vs " << tie_options[i].second
                          << std::setw(10) << decoded
                          << std::setw(10) << "m=" << mantissa
                          << std::setw(20) << (mantissa == 0 ? "Yes ✓" : "No ✗")
                          << "\n";
                break;
            }
        }
    }

    // Special focus on 2.5
    std::cout << "\n\n2.5 Detailed Analysis:\n";
    std::cout << "=====================\n";
    for (size_t i = 0; i < test_values.size(); i++) {
        if (std::abs(test_values[i] - 2.5f) < 1e-6f) {
            uint8_t fp4 = h_outputs[i];
            float decoded = fp4_e2m1_lut[fp4];

            std::cout << "Input: 2.5\n";
            std::cout << "FP4 encoding: 0x" << std::hex << (int)fp4 << std::dec
                      << " (0b" << ((fp4>>3)&1) << ((fp4>>2)&1)
                      << ((fp4>>1)&1) << (fp4&1) << ")\n";
            std::cout << "Decoded value: " << decoded << "\n";
            std::cout << "Mantissa bit: " << (fp4 & 1) << "\n\n";

            std::cout << "Options were:\n";
            std::cout << " 2.0 (0x4, m=0) - even mantissa\n";
            std::cout << " 3.0 (0x5, m=1) - odd mantissa\n\n";

            if (fp4 == 0x4) {
                std::cout << "✓ Correctly rounded to even (2.0)\n";
            } else if (fp4 == 0x5) {
                std::cout << "✗ Rounded to odd (3.0) - NOT round-to-nearest-even\n";
            } else {
                std::cout << "! Unexpected result\n";
            }
            break;
        }
    }

    // Cleanup
    release_device_buffers();

    return 0;
}

0 commit comments

Comments
 (0)