Skip to content

Commit c1e9b1e

Browse files
authored
Merge pull request #9231 from kexinzhao/elementwise_add_fp16
Add float16 support to Elementwise Add op
2 parents d126933 + d307b5e commit c1e9b1e

File tree

3 files changed

+230
-146
lines changed

3 files changed

+230
-146
lines changed

paddle/fluid/operators/elementwise_add_op.cu

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@ limitations under the License. */
1414

1515
#define EIGEN_USE_GPU
1616
#include "paddle/fluid/operators/elementwise_add_op.h"
17+
#include "paddle/fluid/platform/float16.h"
1718

1819
namespace ops = paddle::operators;
20+
namespace plat = paddle::platform;
1921

2022
REGISTER_OP_CUDA_KERNEL(
21-
elementwise_add,
22-
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>,
23-
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, double>,
24-
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int>,
25-
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int64_t>);
23+
elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
24+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
25+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
26+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
27+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
2628
REGISTER_OP_CUDA_KERNEL(
2729
elementwise_add_grad,
28-
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, float>,
29-
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, double>,
30-
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, int>,
31-
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext,
32-
int64_t>);
30+
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
31+
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
32+
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
33+
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);

paddle/fluid/platform/float16.h

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
600600

601601
// Arithmetic operators for float16 on ARMv8.2-A CPU
602602
#elif defined(PADDLE_WITH_NATIVE_FP16)
603-
HOST inline float16 operator+(const float16& a, const float16& b) {
603+
inline float16 operator+(const float16& a, const float16& b) {
604604
float16 res;
605605
asm volatile(
606606
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -616,7 +616,7 @@ HOST inline float16 operator+(const float16& a, const float16& b) {
616616
return res;
617617
}
618618

619-
HOST inline float16 operator-(const float16& a, const float16& b) {
619+
inline float16 operator-(const float16& a, const float16& b) {
620620
float16 res;
621621
asm volatile(
622622
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -632,7 +632,7 @@ HOST inline float16 operator-(const float16& a, const float16& b) {
632632
return res;
633633
}
634634

635-
HOST inline float16 operator*(const float16& a, const float16& b) {
635+
inline float16 operator*(const float16& a, const float16& b) {
636636
float16 res;
637637
asm volatile(
638638
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -648,7 +648,7 @@ HOST inline float16 operator*(const float16& a, const float16& b) {
648648
return res;
649649
}
650650

651-
HOST inline float16 operator/(const float16& a, const float16& b) {
651+
inline float16 operator/(const float16& a, const float16& b) {
652652
float16 res;
653653
asm volatile(
654654
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -664,7 +664,7 @@ HOST inline float16 operator/(const float16& a, const float16& b) {
664664
return res;
665665
}
666666

667-
HOST inline float16 operator-(const float16& a) {
667+
inline float16 operator-(const float16& a) {
668668
float16 res;
669669
asm volatile(
670670
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -679,27 +679,27 @@ HOST inline float16 operator-(const float16& a) {
679679
return res;
680680
}
681681

682-
HOST inline float16& operator+=(float16& a, const float16& b) {
682+
inline float16& operator+=(float16& a, const float16& b) {
683683
a = a + b;
684684
return a;
685685
}
686686

687-
HOST inline float16& operator-=(float16& a, const float16& b) {
687+
inline float16& operator-=(float16& a, const float16& b) {
688688
a = a - b;
689689
return a;
690690
}
691691

692-
HOST inline float16& operator*=(float16& a, const float16& b) {
692+
inline float16& operator*=(float16& a, const float16& b) {
693693
a = a * b;
694694
return a;
695695
}
696696

697-
HOST inline float16& operator/=(float16& a, const float16& b) {
697+
inline float16& operator/=(float16& a, const float16& b) {
698698
a = a / b;
699699
return a;
700700
}
701701

702-
HOST inline bool operator==(const float16& a, const float16& b) {
702+
inline bool operator==(const float16& a, const float16& b) {
703703
uint16_t res;
704704
asm volatile(
705705
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -715,11 +715,9 @@ HOST inline bool operator==(const float16& a, const float16& b) {
715715
return (res & 0xffff) != 0;
716716
}
717717

718-
HOST inline bool operator!=(const float16& a, const float16& b) {
719-
return !(a == b);
720-
}
718+
inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }
721719

722-
HOST inline bool operator<(const float16& a, const float16& b) {
720+
inline bool operator<(const float16& a, const float16& b) {
723721
uint16_t res;
724722
asm volatile(
725723
"ld1 {v1.h}[0], [%[a_ptr]]\n"
@@ -735,7 +733,7 @@ HOST inline bool operator<(const float16& a, const float16& b) {
735733
return (res & 0xffff) != 0;
736734
}
737735

738-
HOST inline bool operator<=(const float16& a, const float16& b) {
736+
inline bool operator<=(const float16& a, const float16& b) {
739737
uint16_t res;
740738
asm volatile(
741739
"ld1 {v1.h}[0], [%[a_ptr]]\n"
@@ -751,7 +749,7 @@ HOST inline bool operator<=(const float16& a, const float16& b) {
751749
return (res & 0xffff) != 0;
752750
}
753751

754-
HOST inline bool operator>(const float16& a, const float16& b) {
752+
inline bool operator>(const float16& a, const float16& b) {
755753
uint16_t res;
756754
asm volatile(
757755
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -767,7 +765,7 @@ HOST inline bool operator>(const float16& a, const float16& b) {
767765
return (res & 0xffff) != 0;
768766
}
769767

770-
HOST inline bool operator>=(const float16& a, const float16& b) {
768+
inline bool operator>=(const float16& a, const float16& b) {
771769
uint16_t res;
772770
asm volatile(
773771
"ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -785,69 +783,69 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
785783

786784
// Arithmetic operators for float16, software emulated on other CPU
787785
#else
788-
HOST inline float16 operator+(const float16& a, const float16& b) {
786+
inline float16 operator+(const float16& a, const float16& b) {
789787
return float16(float(a) + float(b));
790788
}
791789

792-
HOST inline float16 operator-(const float16& a, const float16& b) {
790+
inline float16 operator-(const float16& a, const float16& b) {
793791
return float16(float(a) - float(b));
794792
}
795793

796-
HOST inline float16 operator*(const float16& a, const float16& b) {
794+
inline float16 operator*(const float16& a, const float16& b) {
797795
return float16(float(a) * float(b));
798796
}
799797

800-
HOST inline float16 operator/(const float16& a, const float16& b) {
798+
inline float16 operator/(const float16& a, const float16& b) {
801799
return float16(float(a) / float(b));
802800
}
803801

804-
HOST inline float16 operator-(const float16& a) {
802+
inline float16 operator-(const float16& a) {
805803
float16 res;
806804
res.x = a.x ^ 0x8000;
807805
return res;
808806
}
809807

810-
HOST inline float16& operator+=(float16& a, const float16& b) {
808+
inline float16& operator+=(float16& a, const float16& b) {
811809
a = float16(float(a) + float(b));
812810
return a;
813811
}
814812

815-
HOST inline float16& operator-=(float16& a, const float16& b) {
813+
inline float16& operator-=(float16& a, const float16& b) {
816814
a = float16(float(a) - float(b));
817815
return a;
818816
}
819817

820-
HOST inline float16& operator*=(float16& a, const float16& b) {
818+
inline float16& operator*=(float16& a, const float16& b) {
821819
a = float16(float(a) * float(b));
822820
return a;
823821
}
824822

825-
HOST inline float16& operator/=(float16& a, const float16& b) {
823+
inline float16& operator/=(float16& a, const float16& b) {
826824
a = float16(float(a) / float(b));
827825
return a;
828826
}
829827

830-
HOST inline bool operator==(const float16& a, const float16& b) {
828+
inline bool operator==(const float16& a, const float16& b) {
831829
return float(a) == float(b);
832830
}
833831

834-
HOST inline bool operator!=(const float16& a, const float16& b) {
832+
inline bool operator!=(const float16& a, const float16& b) {
835833
return float(a) != float(b);
836834
}
837835

838-
HOST inline bool operator<(const float16& a, const float16& b) {
836+
inline bool operator<(const float16& a, const float16& b) {
839837
return float(a) < float(b);
840838
}
841839

842-
HOST inline bool operator<=(const float16& a, const float16& b) {
840+
inline bool operator<=(const float16& a, const float16& b) {
843841
return float(a) <= float(b);
844842
}
845843

846-
HOST inline bool operator>(const float16& a, const float16& b) {
844+
inline bool operator>(const float16& a, const float16& b) {
847845
return float(a) > float(b);
848846
}
849847

850-
HOST inline bool operator>=(const float16& a, const float16& b) {
848+
inline bool operator>=(const float16& a, const float16& b) {
851849
return float(a) >= float(b);
852850
}
853851
#endif

0 commit comments

Comments
 (0)