@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
483
483
484
484
#endif // PADDLE_CUDA_FP16
485
485
486
- // Arithmetic operators on ARMv8.2-A CPU
487
- #if defined(PADDLE_WITH_NATIVE_FP16)
486
+ // Arithmetic operators for float16 on GPU
487
+ #if defined(PADDLE_CUDA_FP16)
488
+ HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
489
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
490
+ return float16 (__hadd (half (a), half (b)));
491
+ #else
492
+ return float16 (float (a) + float (b));
493
+ #endif
494
+ }
495
+
496
+ HOSTDEVICE inline float16 operator -(const float16& a, const float16& b) {
497
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
498
+ return float16 (__hsub (half (a), half (b)));
499
+ #else
500
+ return float16 (float (a) - float (b));
501
+ #endif
502
+ }
503
+
504
+ HOSTDEVICE inline float16 operator *(const float16& a, const float16& b) {
505
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
506
+ return float16 (__hmul (half (a), half (b)));
507
+ #else
508
+ return float16 (float (a) * float (b));
509
+ #endif
510
+ }
511
+
512
+ HOSTDEVICE inline float16 operator /(const float16& a, const float16& b) {
513
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
514
+ // TODO(kexinzhao): check which cuda version starts to support __hdiv
515
+ float num = __half2float (half (a));
516
+ float denom = __half2float (half (b));
517
+ return float16 (num / denom);
518
+ #else
519
+ return float16 (float (a) / float (b));
520
+ #endif
521
+ }
522
+
523
+ HOSTDEVICE inline float16 operator -(const float16& a) {
524
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
525
+ return float16 (__hneg (half (a)));
526
+ #else
527
+ float16 res;
528
+ res.x = a.x ^ 0x8000 ;
529
+ return res;
530
+ #endif
531
+ }
532
+
533
+ HOSTDEVICE inline float16& operator +=(float16& a, const float16& b) {
534
+ a = a + b;
535
+ return a;
536
+ }
537
+
538
+ HOSTDEVICE inline float16& operator -=(float16& a, const float16& b) {
539
+ a = a - b;
540
+ return a;
541
+ }
542
+
543
+ HOSTDEVICE inline float16& operator *=(float16& a, const float16& b) {
544
+ a = a * b;
545
+ return a;
546
+ }
547
+
548
+ HOSTDEVICE inline float16& operator /=(float16& a, const float16& b) {
549
+ a = a / b;
550
+ return a;
551
+ }
552
+
553
+ HOSTDEVICE inline bool operator ==(const float16& a, const float16& b) {
554
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
555
+ return __heq (half (a), half (b));
556
+ #else
557
+ return float (a) == float (b);
558
+ #endif
559
+ }
560
+
561
+ HOSTDEVICE inline bool operator !=(const float16& a, const float16& b) {
562
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
563
+ return __hne (half (a), half (b));
564
+ #else
565
+ return float (a) != float (b);
566
+ #endif
567
+ }
568
+
569
+ HOSTDEVICE inline bool operator <(const float16& a, const float16& b) {
570
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
571
+ return __hlt (half (a), half (b));
572
+ #else
573
+ return float (a) < float (b);
574
+ #endif
575
+ }
576
+
577
+ HOSTDEVICE inline bool operator <=(const float16& a, const float16& b) {
578
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
579
+ return __hle (half (a), half (b));
580
+ #else
581
+ return float (a) <= float (b);
582
+ #endif
583
+ }
584
+
585
+ HOSTDEVICE inline bool operator >(const float16& a, const float16& b) {
586
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
587
+ return __hgt (half (a), half (b));
588
+ #else
589
+ return float (a) > float (b);
590
+ #endif
591
+ }
592
+
593
+ HOSTDEVICE inline bool operator >=(const float16& a, const float16& b) {
594
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
595
+ return __hge (half (a), half (b));
596
+ #else
597
+ return float (a) >= float (b);
598
+ #endif
599
+ }
600
+
601
+ // Arithmetic operators for float16 on ARMv8.2-A CPU
602
+ #elif defined(PADDLE_WITH_NATIVE_FP16)
488
603
HOST inline float16 operator +(const float16& a, const float16& b) {
489
604
float16 res;
490
605
asm volatile (
@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
668
783
return (res & 0xffff ) != 0 ;
669
784
}
670
785
671
- // Arithmetic operators, software emulated on other CPU
786
+ // Arithmetic operators for float16, software-emulated on other CPUs
672
787
#else
673
- HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
788
+ HOST inline float16 operator +(const float16& a, const float16& b) {
674
789
return float16 (float (a) + float (b));
675
790
}
676
791
677
- HOSTDEVICE inline float16 operator -(const float16& a, const float16& b) {
792
+ HOST inline float16 operator -(const float16& a, const float16& b) {
678
793
return float16 (float (a) - float (b));
679
794
}
680
795
681
- HOSTDEVICE inline float16 operator *(const float16& a, const float16& b) {
796
+ HOST inline float16 operator *(const float16& a, const float16& b) {
682
797
return float16 (float (a) * float (b));
683
798
}
684
799
685
- HOSTDEVICE inline float16 operator /(const float16& a, const float16& b) {
800
+ HOST inline float16 operator /(const float16& a, const float16& b) {
686
801
return float16 (float (a) / float (b));
687
802
}
688
803
689
- HOSTDEVICE inline float16 operator -(const float16& a) {
804
+ HOST inline float16 operator -(const float16& a) {
690
805
float16 res;
691
806
res.x = a.x ^ 0x8000 ;
692
807
return res;
693
808
}
694
809
695
- HOSTDEVICE inline float16& operator +=(float16& a, const float16& b) {
810
+ HOST inline float16& operator +=(float16& a, const float16& b) {
696
811
a = float16 (float (a) + float (b));
697
812
return a;
698
813
}
699
814
700
- HOSTDEVICE inline float16& operator -=(float16& a, const float16& b) {
815
+ HOST inline float16& operator -=(float16& a, const float16& b) {
701
816
a = float16 (float (a) - float (b));
702
817
return a;
703
818
}
704
819
705
- HOSTDEVICE inline float16& operator *=(float16& a, const float16& b) {
820
+ HOST inline float16& operator *=(float16& a, const float16& b) {
706
821
a = float16 (float (a) * float (b));
707
822
return a;
708
823
}
709
824
710
- HOSTDEVICE inline float16& operator /=(float16& a, const float16& b) {
825
+ HOST inline float16& operator /=(float16& a, const float16& b) {
711
826
a = float16 (float (a) / float (b));
712
827
return a;
713
828
}
714
829
715
- HOSTDEVICE inline bool operator ==(const float16& a, const float16& b) {
830
+ HOST inline bool operator ==(const float16& a, const float16& b) {
716
831
return float (a) == float (b);
717
832
}
718
833
719
- HOSTDEVICE inline bool operator !=(const float16& a, const float16& b) {
834
+ HOST inline bool operator !=(const float16& a, const float16& b) {
720
835
return float (a) != float (b);
721
836
}
722
837
723
- HOSTDEVICE inline bool operator <(const float16& a, const float16& b) {
838
+ HOST inline bool operator <(const float16& a, const float16& b) {
724
839
return float (a) < float (b);
725
840
}
726
841
727
- HOSTDEVICE inline bool operator <=(const float16& a, const float16& b) {
842
+ HOST inline bool operator <=(const float16& a, const float16& b) {
728
843
return float (a) <= float (b);
729
844
}
730
845
731
- HOSTDEVICE inline bool operator >(const float16& a, const float16& b) {
846
+ HOST inline bool operator >(const float16& a, const float16& b) {
732
847
return float (a) > float (b);
733
848
}
734
849
735
- HOSTDEVICE inline bool operator >=(const float16& a, const float16& b) {
850
+ HOST inline bool operator >=(const float16& a, const float16& b) {
736
851
return float (a) >= float (b);
737
852
}
738
853
#endif
0 commit comments