Commit 18d616e

add float16 arithmetic operators on new GPU
1 parent d03dbb9 commit 18d616e

2 files changed, with 82 additions and 7 deletions.

paddle/fluid/platform/float16.h

Lines changed: 72 additions & 3 deletions
@@ -483,8 +483,77 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
 
 #endif  // PADDLE_CUDA_FP16
 
-// Arithmetic operators on ARMv8.2-A CPU
-#if defined(PADDLE_WITH_NATIVE_FP16)
+// Arithmetic operators for float16 on GPU
+#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+DEVICE inline float16 operator+(const float16& a, const float16& b) {
+  return float16(__hadd(half(a), half(b)));
+}
+
+DEVICE inline float16 operator-(const float16& a, const float16& b) {
+  return float16(__hsub(half(a), half(b)));
+}
+
+DEVICE inline float16 operator*(const float16& a, const float16& b) {
+  return float16(__hmul(half(a), half(b)));
+}
+
+DEVICE inline float16 operator/(const float16& a, const float16& b) {
+  // TODO(kexinzhao): check the cuda version that starts to support __hdiv
+  float num = __half2float(half(a));
+  float denom = __half2float(half(b));
+  return float16(num / denom);
+}
+
+DEVICE inline float16 operator-(const float16& a) {
+  return float16(__hneg(half(a)));
+}
+
+DEVICE inline float16& operator+=(float16& a, const float16& b) {
+  a = a + b;
+  return a;
+}
+
+DEVICE inline float16& operator-=(float16& a, const float16& b) {
+  a = a - b;
+  return a;
+}
+
+DEVICE inline float16& operator*=(float16& a, const float16& b) {
+  a = a * b;
+  return a;
+}
+
+DEVICE inline float16& operator/=(float16& a, const float16& b) {
+  a = a / b;
+  return a;
+}
+
+DEVICE inline bool operator==(const float16& a, const float16& b) {
+  return __heq(half(a), half(b));
+}
+
+DEVICE inline bool operator!=(const float16& a, const float16& b) {
+  return __hne(half(a), half(b));
+}
+
+DEVICE inline bool operator<(const float16& a, const float16& b) {
+  return __hlt(half(a), half(b));
+}
+
+DEVICE inline bool operator<=(const float16& a, const float16& b) {
+  return __hle(half(a), half(b));
+}
+
+DEVICE inline bool operator>(const float16& a, const float16& b) {
+  return __hgt(half(a), half(b));
+}
+
+DEVICE inline bool operator>=(const float16& a, const float16& b) {
+  return __hge(half(a), half(b));
+}
+
+// Arithmetic operators for float16 on ARMv8.2-A CPU
+#elif defined(PADDLE_WITH_NATIVE_FP16)
 HOST inline float16 operator+(const float16& a, const float16& b) {
   float16 res;
   asm volatile(
@@ -668,7 +737,7 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
   return (res & 0xffff) != 0;
 }
 
-// Arithmetic operators, software emulated on other CPU
+// Arithmetic operators for float16, software emulated on other CPU/GPU
 #else
 HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
   return float16(float(a) + float(b));
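
With these overloads, device code compiled for compute capability 5.3 or higher can use ordinary arithmetic syntax on paddle::platform::float16 values and have it map onto the half intrinsics shown above, while older GPUs fall through to the software-emulated branch. The following standalone sketch is illustrative only and not part of this commit: the kernel name, launch shape, and test values are invented, and it assumes the Paddle source root is on the include path with the usual CUDA build macros defined so that PADDLE_CUDA_FP16 is active.

#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

#include "paddle/fluid/platform/float16.h"

using paddle::platform::float16;

// Element-wise y[i] += alpha * x[i], written against the overloaded
// float16 operators instead of calling __hmul/__hadd directly.
__global__ void axpy_fp16(int n, float16 alpha, const float16* x, float16* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // On __CUDA_ARCH__ >= 530 this lowers to the half intrinsics above;
    // on older architectures the float round-trip fallback is used.
    y[i] += alpha * x[i];
  }
}

int main() {
  const int n = 1024;
  // Values chosen so the fp16 result is exact: 0.25 + 2.0 * 1.5 = 3.25.
  std::vector<float16> hx(n, float16(1.5f)), hy(n, float16(0.25f));

  float16 *dx = nullptr, *dy = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dx), n * sizeof(float16));
  cudaMalloc(reinterpret_cast<void**>(&dy), n * sizeof(float16));
  cudaMemcpy(dx, hx.data(), n * sizeof(float16), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy.data(), n * sizeof(float16), cudaMemcpyHostToDevice);

  axpy_fp16<<<(n + 255) / 256, 256>>>(n, float16(2.0f), dx, dy);
  cudaMemcpy(hy.data(), dy, n * sizeof(float16), cudaMemcpyDeviceToHost);

  printf("y[0] = %f\n", static_cast<float>(hy[0]));  // expect 3.25

  cudaFree(dx);
  cudaFree(dy);
  return 0;
}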

python/paddle/fluid/tests/unittests/test_dropout_op.py

Lines changed: 10 additions & 4 deletions
@@ -86,10 +86,13 @@ def test_check_output(self):
 class TestFP16DropoutOp1(OpTest):
     def setUp(self):
         x = np.random.random((32, 64)).astype("float16")
+        prob = 0.35
+        out = x * (1.0 - prob)
+
         self.op_type = "dropout"
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
-        self.attrs = {'dropout_prob': 0.35, 'fix_seed': True, 'is_test': True}
-        self.outputs = {'Out': x * (1.0 - self.attrs['dropout_prob'])}
+        self.attrs = {'dropout_prob': prob, 'fix_seed': True, 'is_test': True}
+        self.outputs = {'Out': out}
 
     def test_check_output(self):
         if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
@@ -99,10 +102,13 @@ def test_check_output(self):
 class TestFP16DropoutOp2(OpTest):
     def setUp(self):
         x = np.random.random((32, 64, 3)).astype("float16")
+        prob = 0.75
+        out = x * (1.0 - prob)
+
         self.op_type = "dropout"
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
-        self.attrs = {'dropout_prob': 0.75, 'is_test': True}
-        self.outputs = {'Out': x * (1.0 - self.attrs['dropout_prob'])}
+        self.attrs = {'dropout_prob': prob, 'is_test': True}
+        self.outputs = {'Out': out}
 
     def test_check_output(self):
         if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
