@@ -484,72 +484,107 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
484
484
#endif // PADDLE_CUDA_FP16
485
485
486
486
// Arithmetic operators for float16 on GPU
487
- #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
488
- DEVICE inline float16 operator +(const float16& a, const float16& b) {
487
+ #if defined(PADDLE_CUDA_FP16)
488
+ HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
489
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
489
490
return float16 (__hadd (half (a), half (b)));
491
+ #else
492
+ return float16 (float (a) + float (b));
490
493
}
491
494
492
- DEVICE inline float16 operator -(const float16& a, const float16& b) {
495
+ HOSTDEVICE inline float16 operator -(const float16& a, const float16& b) {
496
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
493
497
return float16 (__hsub (half (a), half (b)));
498
+ #else
499
+ return float16 (float (a) - float (b));
494
500
}
495
501
496
- DEVICE inline float16 operator *(const float16& a, const float16& b) {
502
+ HOSTDEVICE inline float16 operator *(const float16& a, const float16& b) {
503
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
497
504
return float16 (__hmul (half (a), half (b)));
505
+ #else
506
+ return float16 (float (a) * float (b));
498
507
}
499
508
500
- DEVICE inline float16 operator /(const float16& a, const float16& b) {
501
- // TODO(kexinzhao): check the cuda version that starts to support __hdiv
509
+ HOSTDEVICE inline float16 operator /(const float16& a, const float16& b) {
510
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
511
+ // TODO(kexinzhao): check which cuda version starts to support __hdiv
502
512
float num = __half2float (half (a));
503
513
float denom = __half2float (half (b));
504
514
return float16 (num / denom);
515
+ #else
516
+ return float16 (float (a) / float (b));
505
517
}
506
518
507
- DEVICE inline float16 operator -(const float16& a) {
519
+ HOSTDEVICE inline float16 operator -(const float16& a) {
520
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
508
521
return float16 (__hneg (half (a)));
522
+ #else
523
+ float16 res;
524
+ res.x = a.x ^ 0x8000 ;
525
+ return res;
509
526
}
510
527
511
- DEVICE inline float16& operator +=(float16& a, const float16& b) {
528
+ HOSTDEVICE inline float16& operator +=(float16& a, const float16& b) {
512
529
a = a + b;
513
530
return a;
514
531
}
515
532
516
- DEVICE inline float16& operator -=(float16& a, const float16& b) {
533
+ HOSTDEVICE inline float16& operator -=(float16& a, const float16& b) {
517
534
a = a - b;
518
535
return a;
519
536
}
520
537
521
- DEVICE inline float16& operator *=(float16& a, const float16& b) {
538
+ HOSTDEVICE inline float16& operator *=(float16& a, const float16& b) {
522
539
a = a * b;
523
540
return a;
524
541
}
525
542
526
- DEVICE inline float16& operator /=(float16& a, const float16& b) {
543
+ HOSTDEVICE inline float16& operator /=(float16& a, const float16& b) {
527
544
a = a / b;
528
545
return a;
529
546
}
530
547
531
- DEVICE inline bool operator ==(const float16& a, const float16& b) {
548
+ HOSTDEVICE inline bool operator ==(const float16& a, const float16& b) {
549
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
532
550
return __heq (half (a), half (b));
551
+ #else
552
+ return float (a) == float (b);
533
553
}
534
554
535
- DEVICE inline bool operator !=(const float16& a, const float16& b) {
555
+ HOSTDEVICE inline bool operator !=(const float16& a, const float16& b) {
556
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
536
557
return __hne (half (a), half (b));
558
+ #else
559
+ return float (a) != float (b);
537
560
}
538
561
539
- DEVICE inline bool operator <(const float16& a, const float16& b) {
562
+ HOSTDEVICE inline bool operator <(const float16& a, const float16& b) {
563
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
540
564
return __hlt (half (a), half (b));
565
+ #else
566
+ return float (a) < float (b);
541
567
}
542
568
543
- DEVICE inline bool operator <=(const float16& a, const float16& b) {
569
+ HOSTDEVICE inline bool operator <=(const float16& a, const float16& b) {
570
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
544
571
return __hle (half (a), half (b));
572
+ #else
573
+ return float (a) <= float (b);
545
574
}
546
575
547
- DEVICE inline bool operator >(const float16& a, const float16& b) {
576
+ HOSTDEVICE inline bool operator >(const float16& a, const float16& b) {
577
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
548
578
return __hgt (half (a), half (b));
579
+ #else
580
+ return float (a) > float (b);
549
581
}
550
582
551
- DEVICE inline bool operator >=(const float16& a, const float16& b) {
583
+ HOSTDEVICE inline bool operator >=(const float16& a, const float16& b) {
584
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
552
585
return __hge (half (a), half (b));
586
+ #else
587
+ return float (a) >= float (b);
553
588
}
554
589
555
590
// Arithmetic operators for float16 on ARMv8.2-A CPU
@@ -737,71 +772,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
737
772
return (res & 0xffff ) != 0 ;
738
773
}
739
774
740
- // Arithmetic operators for float16, software emulated on other CPU/GPU
775
+ // Arithmetic operators for float16, software emulated on other CPU
741
776
#else
742
- HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
777
+ HOST inline float16 operator +(const float16& a, const float16& b) {
743
778
return float16 (float (a) + float (b));
744
779
}
745
780
746
- HOSTDEVICE inline float16 operator -(const float16& a, const float16& b) {
781
+ HOST inline float16 operator -(const float16& a, const float16& b) {
747
782
return float16 (float (a) - float (b));
748
783
}
749
784
750
- HOSTDEVICE inline float16 operator *(const float16& a, const float16& b) {
785
+ HOST inline float16 operator *(const float16& a, const float16& b) {
751
786
return float16 (float (a) * float (b));
752
787
}
753
788
754
- HOSTDEVICE inline float16 operator /(const float16& a, const float16& b) {
789
+ HOST inline float16 operator /(const float16& a, const float16& b) {
755
790
return float16 (float (a) / float (b));
756
791
}
757
792
758
- HOSTDEVICE inline float16 operator -(const float16& a) {
793
+ HOST inline float16 operator -(const float16& a) {
759
794
float16 res;
760
795
res.x = a.x ^ 0x8000 ;
761
796
return res;
762
797
}
763
798
764
- HOSTDEVICE inline float16& operator +=(float16& a, const float16& b) {
799
+ HOST inline float16& operator +=(float16& a, const float16& b) {
765
800
a = float16 (float (a) + float (b));
766
801
return a;
767
802
}
768
803
769
- HOSTDEVICE inline float16& operator -=(float16& a, const float16& b) {
804
+ HOST inline float16& operator -=(float16& a, const float16& b) {
770
805
a = float16 (float (a) - float (b));
771
806
return a;
772
807
}
773
808
774
- HOSTDEVICE inline float16& operator *=(float16& a, const float16& b) {
809
+ HOST inline float16& operator *=(float16& a, const float16& b) {
775
810
a = float16 (float (a) * float (b));
776
811
return a;
777
812
}
778
813
779
- HOSTDEVICE inline float16& operator /=(float16& a, const float16& b) {
814
+ HOST inline float16& operator /=(float16& a, const float16& b) {
780
815
a = float16 (float (a) / float (b));
781
816
return a;
782
817
}
783
818
784
- HOSTDEVICE inline bool operator ==(const float16& a, const float16& b) {
819
+ HOST inline bool operator ==(const float16& a, const float16& b) {
785
820
return float (a) == float (b);
786
821
}
787
822
788
- HOSTDEVICE inline bool operator !=(const float16& a, const float16& b) {
823
+ HOST inline bool operator !=(const float16& a, const float16& b) {
789
824
return float (a) != float (b);
790
825
}
791
826
792
- HOSTDEVICE inline bool operator <(const float16& a, const float16& b) {
827
+ HOST inline bool operator <(const float16& a, const float16& b) {
793
828
return float (a) < float (b);
794
829
}
795
830
796
- HOSTDEVICE inline bool operator <=(const float16& a, const float16& b) {
831
+ HOST inline bool operator <=(const float16& a, const float16& b) {
797
832
return float (a) <= float (b);
798
833
}
799
834
800
- HOSTDEVICE inline bool operator >(const float16& a, const float16& b) {
835
+ HOST inline bool operator >(const float16& a, const float16& b) {
801
836
return float (a) > float (b);
802
837
}
803
838
804
- HOSTDEVICE inline bool operator >=(const float16& a, const float16& b) {
839
+ HOST inline bool operator >=(const float16& a, const float16& b) {
805
840
return float (a) >= float (b);
806
841
}
807
842
#endif
0 commit comments