Skip to content

Commit f2bbbb2

Browse files
committed
fix arithmetic operator
1 parent 18d616e commit f2bbbb2

File tree

1 file changed

+68
-33
lines changed

1 file changed

+68
-33
lines changed

paddle/fluid/platform/float16.h

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -484,72 +484,107 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
484484
#endif // PADDLE_CUDA_FP16
485485

486486
// Arithmetic operators for float16 on GPU
487-
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
488-
DEVICE inline float16 operator+(const float16& a, const float16& b) {
487+
#if defined(PADDLE_CUDA_FP16)
488+
// Adds two float16 values.
// In a device compilation pass targeting __CUDA_ARCH__ >= 530 the sum is
// computed with the __hadd half intrinsic; in every other pass (host code or
// older device architectures) both operands are converted to float, added,
// and the result is converted back to float16.
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hadd(half(a), half(b)));
#else
  return float16(float(a) + float(b));
#endif
}
491494

492-
DEVICE inline float16 operator-(const float16& a, const float16& b) {
495+
// Subtracts b from a.
// Uses the __hsub half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise performs the subtraction in float and
// converts the difference back to float16.
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hsub(half(a), half(b)));
#else
  return float16(float(a) - float(b));
#endif
}
495501

496-
DEVICE inline float16 operator*(const float16& a, const float16& b) {
502+
// Multiplies two float16 values.
// Uses the __hmul half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise multiplies in float and converts the
// product back to float16.
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hmul(half(a), half(b)));
#else
  return float16(float(a) * float(b));
#endif
}
499508

500-
DEVICE inline float16 operator/(const float16& a, const float16& b) {
501-
// TODO(kexinzhao): check the cuda version that starts to support __hdiv
509+
// Divides a by b.
// Both branches divide in float precision and convert the quotient back to
// float16. The device branch (__CUDA_ARCH__ >= 300) converts the operands
// through half and __half2float; the fallback branch uses float16's own
// conversion to float.
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#else
  return float16(float(a) / float(b));
#endif
}
506518

507-
DEVICE inline float16 operator-(const float16& a) {
519+
// Negates a float16 value.
// Uses the __hneg half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise toggles the 0x8000 bit of the raw storage
// `x` (the sign bit in the IEEE 754 binary16 layout) without touching the
// remaining bits.
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
#endif
}
510527

511-
DEVICE inline float16& operator+=(float16& a, const float16& b) {
528+
// In-place addition: stores a + b into a and returns a reference to a.
// The arithmetic itself is delegated to the binary operator+.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  const float16 sum = a + b;
  a = sum;
  return a;
}
515532

516-
DEVICE inline float16& operator-=(float16& a, const float16& b) {
533+
// In-place subtraction: stores a - b into a and returns a reference to a.
// The arithmetic itself is delegated to the binary operator-.
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  const float16 difference = a - b;
  a = difference;
  return a;
}
520537

521-
DEVICE inline float16& operator*=(float16& a, const float16& b) {
538+
// In-place multiplication: stores a * b into a and returns a reference to a.
// The arithmetic itself is delegated to the binary operator*.
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  const float16 product = a * b;
  a = product;
  return a;
}
525542

526-
DEVICE inline float16& operator/=(float16& a, const float16& b) {
543+
// In-place division: stores a / b into a and returns a reference to a.
// The arithmetic itself is delegated to the binary operator/.
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  const float16 quotient = a / b;
  a = quotient;
  return a;
}
530547

531-
DEVICE inline bool operator==(const float16& a, const float16& b) {
548+
// Equality comparison of two float16 values.
// Uses the __heq half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise compares the operands after converting
// both to float.
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  return float(a) == float(b);
#endif
}
534554

535-
DEVICE inline bool operator!=(const float16& a, const float16& b) {
555+
// Inequality comparison of two float16 values.
// Uses the __hne half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise compares the operands after converting
// both to float.
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(half(a), half(b));
#else
  return float(a) != float(b);
#endif
}
538561

539-
DEVICE inline bool operator<(const float16& a, const float16& b) {
562+
// Less-than comparison of two float16 values.
// Uses the __hlt half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise compares the operands after converting
// both to float.
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  return float(a) < float(b);
#endif
}
542568

543-
DEVICE inline bool operator<=(const float16& a, const float16& b) {
569+
// Less-than-or-equal comparison of two float16 values.
// Uses the __hle half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise compares the operands after converting
// both to float.
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  return float(a) <= float(b);
#endif
}
546575

547-
DEVICE inline bool operator>(const float16& a, const float16& b) {
576+
// Greater-than comparison of two float16 values.
// Uses the __hgt half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise compares the operands after converting
// both to float.
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  return float(a) > float(b);
#endif
}
550582

551-
DEVICE inline bool operator>=(const float16& a, const float16& b) {
583+
// Greater-than-or-equal comparison of two float16 values.
// Uses the __hge half intrinsic when compiling device code for
// __CUDA_ARCH__ >= 530; otherwise compares the operands after converting
// both to float.
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  return float(a) >= float(b);
#endif
}
554589

555590
// Arithmetic operators for float16 on ARMv8.2-A CPU
@@ -737,71 +772,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
737772
return (res & 0xffff) != 0;
738773
}
739774

740-
// Arithmetic operators for float16, software emulated on other CPU/GPU
775+
// Arithmetic operators for float16, software emulated on other CPU
741776
#else
742-
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
777+
// Software-emulated addition: converts both operands to float, adds them,
// and converts the sum back to float16.
HOST inline float16 operator+(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs + rhs);
}
745780

746-
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
781+
// Software-emulated subtraction: converts both operands to float, subtracts
// b from a, and converts the difference back to float16.
HOST inline float16 operator-(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs - rhs);
}
749784

750-
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
785+
// Software-emulated multiplication: converts both operands to float,
// multiplies them, and converts the product back to float16.
HOST inline float16 operator*(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs * rhs);
}
753788

754-
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
789+
// Software-emulated division: converts both operands to float, divides a by
// b, and converts the quotient back to float16.
HOST inline float16 operator/(const float16& a, const float16& b) {
  const float numerator = float(a);
  const float denominator = float(b);
  return float16(numerator / denominator);
}
757792

758-
HOSTDEVICE inline float16 operator-(const float16& a) {
793+
// Negation without any float round trip: the result's raw storage `x` is
// a.x with the 0x8000 bit toggled (the sign bit in the IEEE 754 binary16
// layout).
HOST inline float16 operator-(const float16& a) {
  const int kSignMask = 0x8000;
  float16 negated;
  negated.x = a.x ^ kSignMask;
  return negated;
}
763798

764-
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
799+
// In-place addition: computes the sum in float, stores it into a as a
// float16, and returns a reference to a.
HOST inline float16& operator+=(float16& a, const float16& b) {
  const float sum = float(a) + float(b);
  a = float16(sum);
  return a;
}
768803

769-
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
804+
// In-place subtraction: computes a - b in float, stores it into a as a
// float16, and returns a reference to a.
HOST inline float16& operator-=(float16& a, const float16& b) {
  const float difference = float(a) - float(b);
  a = float16(difference);
  return a;
}
773808

774-
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
809+
// In-place multiplication: computes the product in float, stores it into a
// as a float16, and returns a reference to a.
HOST inline float16& operator*=(float16& a, const float16& b) {
  const float product = float(a) * float(b);
  a = float16(product);
  return a;
}
778813

779-
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
814+
// In-place division: computes a / b in float, stores it into a as a
// float16, and returns a reference to a.
HOST inline float16& operator/=(float16& a, const float16& b) {
  const float quotient = float(a) / float(b);
  a = float16(quotient);
  return a;
}
783818

784-
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
819+
// Equality test performed on the float conversions of both operands.
HOST inline bool operator==(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs == rhs;
}
787822

788-
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
823+
// Inequality test performed on the float conversions of both operands.
HOST inline bool operator!=(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs != rhs;
}
791826

792-
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
827+
// Less-than test performed on the float conversions of both operands.
HOST inline bool operator<(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs < rhs;
}
795830

796-
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
831+
// Less-than-or-equal test performed on the float conversions of both
// operands.
HOST inline bool operator<=(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs <= rhs;
}
799834

800-
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
835+
// Greater-than test performed on the float conversions of both operands.
// Written as "b < a", which is the same relation with the operands swapped.
HOST inline bool operator>(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return rhs < lhs;
}
803838

804-
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
839+
// Greater-than-or-equal test performed on the float conversions of both
// operands. Written as "b <= a", which is the same relation with the
// operands swapped.
HOST inline bool operator>=(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return rhs <= lhs;
}
807842
#endif

0 commit comments

Comments
 (0)