@@ -483,8 +483,77 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
483
483
484
484
#endif // PADDLE_CUDA_FP16
485
485
486
- // Arithmetic operators on ARMv8.2-A CPU
487
- #if defined(PADDLE_WITH_NATIVE_FP16)
486
+ // Arithmetic operators for float16 on GPU
487
+ #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
488
+ DEVICE inline float16 operator +(const float16& a, const float16& b) {
489
+ return float16 (__hadd (half (a), half (b)));
490
+ }
491
+
492
+ DEVICE inline float16 operator -(const float16& a, const float16& b) {
493
+ return float16 (__hsub (half (a), half (b)));
494
+ }
495
+
496
+ DEVICE inline float16 operator *(const float16& a, const float16& b) {
497
+ return float16 (__hmul (half (a), half (b)));
498
+ }
499
+
500
+ DEVICE inline float16 operator /(const float16& a, const float16& b) {
501
+ // TODO(kexinzhao): check the cuda version that starts to support __hdiv
502
+ float num = __half2float (half (a));
503
+ float denom = __half2float (half (b));
504
+ return float16 (num / denom);
505
+ }
506
+
507
+ DEVICE inline float16 operator -(const float16& a) {
508
+ return float16 (__hneg (half (a)));
509
+ }
510
+
511
+ DEVICE inline float16& operator +=(float16& a, const float16& b) {
512
+ a = a + b;
513
+ return a;
514
+ }
515
+
516
+ DEVICE inline float16& operator -=(float16& a, const float16& b) {
517
+ a = a - b;
518
+ return a;
519
+ }
520
+
521
+ DEVICE inline float16& operator *=(float16& a, const float16& b) {
522
+ a = a * b;
523
+ return a;
524
+ }
525
+
526
+ DEVICE inline float16& operator /=(float16& a, const float16& b) {
527
+ a = a / b;
528
+ return a;
529
+ }
530
+
531
+ DEVICE inline bool operator ==(const float16& a, const float16& b) {
532
+ return __heq (half (a), half (b));
533
+ }
534
+
535
+ DEVICE inline bool operator !=(const float16& a, const float16& b) {
536
+ return __hne (half (a), half (b));
537
+ }
538
+
539
+ DEVICE inline bool operator <(const float16& a, const float16& b) {
540
+ return __hlt (half (a), half (b));
541
+ }
542
+
543
+ DEVICE inline bool operator <=(const float16& a, const float16& b) {
544
+ return __hle (half (a), half (b));
545
+ }
546
+
547
+ DEVICE inline bool operator >(const float16& a, const float16& b) {
548
+ return __hgt (half (a), half (b));
549
+ }
550
+
551
+ DEVICE inline bool operator >=(const float16& a, const float16& b) {
552
+ return __hge (half (a), half (b));
553
+ }
554
+
555
+ // Arithmetic operators for float16 on ARMv8.2-A CPU
556
+ #elif defined(PADDLE_WITH_NATIVE_FP16)
488
557
HOST inline float16 operator +(const float16& a, const float16& b) {
489
558
float16 res;
490
559
asm volatile (
@@ -668,7 +737,7 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
668
737
return (res & 0xffff ) != 0 ;
669
738
}
670
739
671
- // Arithmetic operators, software emulated on other CPU
740
+ // Arithmetic operators for float16 , software emulated on other CPU/GPU
672
741
#else
673
742
HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
674
743
return float16 (float (a) + float (b));
0 commit comments