Commit 3557235

Edit float16 doc (#5851)
* Add survey of support of half in different CUDA versions
* small fix
1 parent 7300655 commit 3557235

1 file changed (+45, -0)

doc/design/float16.md

Lines changed: 45 additions & 0 deletions
@@ -28,6 +28,51 @@ The goal of float16 is to serve as a key for the executor to find and run the co
- [Eigen](https://github.com/RLovelett/eigen) >= 3.3 supports float16 calculation on both GPU and CPU using the `Eigen::half` class. It is mostly useful for Nvidia GPUs because its overloaded arithmetic operators use CUDA intrinsics. On CPU it falls back to software emulation, with no special treatment for ARM processors.
- [ARM compute library](https://github.com/ARM-software/ComputeLibrary) >= 17.02.01 supports NEON FP16 kernels (requires ARMv8.2-A CPU).

### CUDA version issue
Three versions of CUDA currently support the `__half` data type: CUDA 7.5, 8.0, and 9.0.
CUDA 7.5 and 8.0 define `__half` as a simple struct that holds a single `uint16_t` member (see [`cuda_fp16.h`](https://github.com/ptillet/isaac/blob/9212ab5a3ddbe48f30ef373f9c1fb546804c7a8c/include/isaac/external/CUDA/cuda_fp16.h)) as follows:
```
typedef struct __align__(2) {
  unsigned short x;
} __half;

typedef __half half;
```
This struct does not define any overloaded arithmetic operators, so you must call the `__hadd` intrinsic directly instead of using `+` to add two half values:
```
__global__ void Add() {
  half a, b, c;
  c = __hadd(a, b); // correct
  c = a + b; // compiler error: no operator "+" matches these operands
}
```
CUDA 9.0 provides a major update to the half data type. The related code can be found in the updated [`cuda_fp16.h`](https://github.com/ptillet/isaac/blob/master/include/isaac/external/CUDA/cuda_fp16.h) and the newly added [`cuda_fp16.hpp`](https://github.com/ptillet/isaac/blob/master/include/isaac/external/CUDA/cuda_fp16.hpp).

Essentially, CUDA 9.0 renames the original `__half` type from 7.5 and 8.0 to `__half_raw` and defines a new `__half` class type that has constructors and conversion operators, along with overloaded arithmetic operators, as follows:
```
typedef struct __CUDA_ALIGN__(2) {
  unsigned short x;
} __half_raw;


struct __CUDA_ALIGN__(2) __half {
protected:
  unsigned short __x;
public:
  // constructors and conversion operators from/to
  // __half_raw and other built-in data types
};

typedef __half half;

__device__ __forceinline__
__half operator+(const __half &lh, const __half &rh) {
  return __hadd(lh, rh);
}

// Other overloaded operators
```
This new design makes `c = a + b` work correctly for the CUDA half data type.

## Implementation
The float16 class stores its value internally as a 16-bit `uint16_t`.
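
To make the storage idea concrete, here is a minimal host-only C++ sketch of a type that keeps its value in a single `uint16_t`. It is an illustration under stated assumptions, not the class this document goes on to design: the helper names `float_to_bits` and `bits_to_float` are hypothetical, the bit layout is assumed to be IEEE 754 binary16 (1 sign, 5 exponent, 10 mantissa bits), and `operator+` is emulated by round-tripping through `float`, similar in spirit to the CPU software fallback mentioned for Eigen above.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical sketch: a 16-bit float that stores raw binary16 bits.
struct float16 {
  std::uint16_t x;  // raw bits: 1 sign | 5 exponent | 10 mantissa

  float16() : x(0) {}
  explicit float16(float f) : x(float_to_bits(f)) {}
  explicit operator float() const { return bits_to_float(x); }

  // float -> binary16 bits, rounding to nearest (ties to even).
  static std::uint16_t float_to_bits(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    const std::uint32_t sign = (u >> 16) & 0x8000u;
    const std::uint32_t exp = (u >> 23) & 0xFFu;
    std::uint32_t man = u & 0x7FFFFFu;

    if (exp == 0xFFu)  // infinity or NaN
      return static_cast<std::uint16_t>(sign | 0x7C00u | (man ? 0x200u : 0u));

    const int e = static_cast<int>(exp) - 127 + 15;  // rebias the exponent
    if (e >= 31)  // too large for binary16: becomes infinity
      return static_cast<std::uint16_t>(sign | 0x7C00u);

    if (e <= 0) {  // result is a binary16 subnormal or zero
      if (e < -10) return static_cast<std::uint16_t>(sign);
      man |= 0x800000u;  // restore the implicit leading 1
      const int shift = 14 - e;
      std::uint32_t h = man >> shift;
      const std::uint32_t rem = man & ((1u << shift) - 1u);
      const std::uint32_t halfway = 1u << (shift - 1);
      if (rem > halfway || (rem == halfway && (h & 1u))) ++h;
      return static_cast<std::uint16_t>(sign | h);
    }

    std::uint32_t h = (static_cast<std::uint32_t>(e) << 10) | (man >> 13);
    const std::uint32_t rem = man & 0x1FFFu;
    // A rounding carry may roll into the exponent field, which is intended.
    if (rem > 0x1000u || (rem == 0x1000u && (h & 1u))) ++h;
    return static_cast<std::uint16_t>(sign | h);
  }

  // binary16 bits -> float (always exact, since float is wider).
  static float bits_to_float(std::uint16_t h) {
    const std::uint32_t sign = (h & 0x8000u) ? 0x80000000u : 0u;
    const std::uint32_t exp = (h >> 10) & 0x1Fu;
    const std::uint32_t man = h & 0x3FFu;

    if (exp == 0u) {  // zero or subnormal: value is man * 2^-24
      const float mag = std::ldexp(static_cast<float>(man), -24);
      return sign ? -mag : mag;
    }
    std::uint32_t u;
    if (exp == 0x1Fu)  // infinity or NaN
      u = sign | 0x7F800000u | (man << 13);
    else               // normal number: rebias from 15 to 127
      u = sign | ((exp + 112u) << 23) | (man << 13);
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

static_assert(sizeof(float16) == 2, "float16 should occupy exactly 16 bits");

// Emulated addition: widen to float, add, then round back to 16 bits.
inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}
```

With this sketch, `float16(1.0f).x` is `0x3C00`, and `static_cast<float>(float16(1.5f) + float16(2.25f))` yields `3.75f`. Keeping the payload as plain bits means the same 2-byte value could, in principle, be reinterpreted as whichever native half type a given backend provides.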
