Skip to content

Commit ce94788

Browse files
authored
add casting (#7)
1 parent ee1c000 commit ce94788

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ foreach(EXAMPLE_SOURCE ${EXAMPLE_SOURCES})
5151

5252
# CUDA properties provided by CMAKE
5353
set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
54-
set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_ARCHITECTURES 90a)
54+
set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_ARCHITECTURES 100)
5555

5656
# Convert the flags string into a list of flags
5757
separate_arguments(EXTRA_CUDA_FLAGS_LIST UNIX_COMMAND "${EXTRA_CUDA_FLAGS}")

examples/mx/num_to_ue8mo.cu

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include <cuda_fp8.h>
2+
#include <stdio.h>
3+
4+
/**
 * Single-value demo kernel: converts in[0] to an 8-bit e8m0 code.
 *
 * Reads the float at in[0], prints it from the device, converts it with
 * __nv_cvt_float_to_e8m0 using saturate-to-finite, prints the resulting byte
 * in hex and decimal, and stores that byte in out[0].
 *
 * Launch layout: every thread reads in[0] and writes out[0], so this is meant
 * to be launched with a single thread (main launches it as <<<1, 1>>>).
 * Uses no shared memory. Device printf output becomes visible on the host
 * only after a synchronizing call.
 *
 * @param in   Device pointer to at least one readable float.
 * @param out  Device pointer to at least one writable __nv_fp8_storage_t.
 */
__global__ void convert_to_e8m0(float *in, __nv_fp8_storage_t *out) {
  const float value = in[0];
  printf("Device input value: %f\n", value);

  // NOTE(review): confirm that cudaRoundNearest is a rounding mode the e8m0
  // conversion actually honours; check the cuda_fp8.h documentation.
  const __nv_fp8_storage_t encoded =
      __nv_cvt_float_to_e8m0(value, __NV_SATFINITE, cudaRoundNearest);

  // Print the raw storage byte twice: once as hex, once as decimal.
  const unsigned char raw = static_cast<unsigned char>(encoded);
  printf("Device output value (hex): 0x%02x, (decimal): %u\n", raw, raw);

  out[0] = encoded;
}
13+
14+
// Reports a failed CUDA runtime call on stderr as "<what> error: <message>".
// Returns true when `err` is cudaSuccess, false otherwise.
static bool cudaCallSucceeded(cudaError_t err, const char *what) {
  if (err == cudaSuccess) {
    return true;
  }
  fprintf(stderr, "%s error: %s\n", what, cudaGetErrorString(err));
  return false;
}

/**
 * Host driver for the e8m0 conversion demo.
 *
 * Copies the single float 1/448 to the device, runs convert_to_e8m0 with one
 * thread, copies the resulting byte back and prints it as hex, decimal and
 * individual bits.
 *
 * Every CUDA runtime call is checked. On the first failure the remaining
 * steps are skipped, but the device buffers are still released before
 * returning.
 *
 * @return 0 on success, 1 if any CUDA call failed.
 */
int main() {
  const float h_in = 1.0f / 448.0f;
  float *d_in = nullptr;
  __nv_fp8_storage_t *d_out = nullptr;
  __nv_fp8_storage_t h_out = 0;

  // Allocate both device buffers and upload the input; && stops at the first
  // failing call so later steps never run on a bad state.
  bool ok =
      cudaCallSucceeded(cudaMalloc(&d_in, sizeof(float)), "Malloc") &&
      cudaCallSucceeded(cudaMalloc(&d_out, sizeof(__nv_fp8_storage_t)),
                        "Malloc") &&
      cudaCallSucceeded(
          cudaMemcpy(d_in, &h_in, sizeof(float), cudaMemcpyHostToDevice),
          "Memcpy");

  if (ok) {
    convert_to_e8m0<<<1, 1>>>(d_in, d_out);
    // cudaGetLastError catches launch-configuration errors; the synchronize
    // surfaces faults raised while the kernel runs and is also what flushes
    // the kernel's printf output to the host.
    ok = cudaCallSucceeded(cudaGetLastError(), "Kernel") &&
         cudaCallSucceeded(cudaDeviceSynchronize(), "Kernel");
  }

  if (ok) {
    ok = cudaCallSucceeded(cudaMemcpy(&h_out, d_out,
                                      sizeof(__nv_fp8_storage_t),
                                      cudaMemcpyDeviceToHost),
                           "Memcpy");
  }

  if (ok) {
    printf("Host input float: %f\n", h_in);
    printf("Host output e8m0 hex: 0x%02x, decimal: %u\n", (unsigned char)h_out,
           (unsigned char)h_out);
    // Most-significant bit first.
    printf("Host output e8m0 bits: ");
    for (int i = 7; i >= 0; i--) {
      printf("%d", (h_out >> i) & 0x1);
    }
    printf("\n");
  }

  // Always release the buffers; cudaFree on a null pointer is a no-op, so this
  // is safe even when an allocation above failed.
  const cudaError_t free_in = cudaFree(d_in);
  const cudaError_t free_out = cudaFree(d_out);
  if (ok) {
    // Only report cleanup failures when nothing else has gone wrong, to avoid
    // repeating an earlier error that the runtime keeps returning.
    ok = cudaCallSucceeded(free_in, "Free") &&
         cudaCallSucceeded(free_out, "Free");
  }

  return ok ? 0 : 1;
}

0 commit comments

Comments
 (0)