@@ -14,7 +14,7 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
- #include < cstdint >
17
+ #include < stdint.h >
18
18
19
19
#ifdef PADDLE_WITH_CUDA
20
20
#include < cuda.h>
@@ -71,6 +71,7 @@ struct PADDLE_ALIGN(2) float16 {
71
71
public:
72
72
uint16_t x;
73
73
74
+ // Constructors
74
75
HOSTDEVICE inline float16 () : x (0 ) {}
75
76
76
77
HOSTDEVICE inline float16 (const float16& h) : x (h.x ) {}
@@ -89,8 +90,7 @@ struct PADDLE_ALIGN(2) float16 {
89
90
90
91
#ifdef PADDLE_WITH_NATIVE_FP16
91
92
// __fp16 is a native half precision data type for arm cpu,
92
- // float16_t is an alias for __fp16 in arm_fp16.h,
93
- // which is included in arm_neon.h.
93
+ // float16_t is an alias for __fp16
94
94
HOSTDEVICE inline explicit float16 (const float16_t & h) {
95
95
x = *reinterpret_cast <const uint16_t *>(&h);
96
96
}
@@ -141,6 +141,7 @@ struct PADDLE_ALIGN(2) float16 {
141
141
return *this ;
142
142
}
143
143
144
+ // Assignment operators
144
145
#ifdef PADDLE_CUDA_FP16
145
146
HOSTDEVICE inline float16& operator =(const half& rhs) {
146
147
#if CUDA_VERSION >= 9000
@@ -219,6 +220,7 @@ struct PADDLE_ALIGN(2) float16 {
219
220
return *this ;
220
221
}
221
222
223
+ // Conversion operators
222
224
#ifdef PADDLE_CUDA_FP16
223
225
HOSTDEVICE inline explicit operator half () const {
224
226
#if CUDA_VERSION >= 9000
@@ -353,27 +355,54 @@ struct PADDLE_ALIGN(2) float16 {
353
355
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
354
356
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
355
357
// CUDA 9.0 regarding the half data type.
356
- #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && \
357
- __CUDA_ARCH__ >= 530 && CUDA_VERSION < 9000
358
+ #if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000
359
+
358
360
DEVICE inline half operator +(const half& a, const half& b) {
361
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
359
362
return __hadd (a, b);
363
+ #else
364
+ float res = float (float16 (a)) + float (float16 (b));
365
+ return half (float16 (res));
366
+ #endif
360
367
}
361
368
362
369
DEVICE inline half operator -(const half& a, const half& b) {
370
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
363
371
return __hsub (a, b);
372
+ #else
373
+ float res = float (float16 (a)) - float (float16 (b));
374
+ return half (float16 (res));
375
+ #endif
364
376
}
365
377
366
378
DEVICE inline half operator *(const half& a, const half& b) {
379
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
367
380
return __hmul (a, b);
381
+ #else
382
+ float res = float (float16 (a)) * float (float16 (b));
383
+ return half (float16 (res));
384
+ #endif
368
385
}
369
386
370
387
DEVICE inline half operator /(const half& a, const half& b) {
388
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
371
389
float num = __half2float (a);
372
390
float denom = __half2float (b);
373
391
return __float2half (num / denom);
392
+ #else
393
+ float res = float (float16 (a)) / float (float16 (b));
394
+ return half (float16 (res));
395
+ #endif
374
396
}
375
397
376
- DEVICE inline half operator -(const half& a) { return __hneg (a); }
398
+ DEVICE inline half operator -(const half& a) {
399
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
400
+ return __hneg (a);
401
+ #else
402
+ float res = -float (float16 (a));
403
+ return half (float16 (res));
404
+ #endif
405
+ }
377
406
378
407
DEVICE inline half& operator +=(half& a, const half& b) {
379
408
a = a + b;
@@ -396,99 +425,57 @@ DEVICE inline half& operator/=(half& a, const half& b) {
396
425
}
397
426
398
427
DEVICE inline bool operator ==(const half& a, const half& b) {
428
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
399
429
return __heq (a, b);
430
+ #else
431
+ return float (float16 (a)) == float (float16 (b));
432
+ #endif
400
433
}
401
434
402
435
DEVICE inline bool operator !=(const half& a, const half& b) {
436
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
403
437
return __hne (a, b);
438
+ #else
439
+ return float (float16 (a)) != float (float16 (b));
440
+ #endif
404
441
}
405
442
406
443
DEVICE inline bool operator <(const half& a, const half& b) {
444
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
407
445
return __hlt (a, b);
446
+ #else
447
+ return float (float16 (a)) < float (float16 (b));
448
+ #endif
408
449
}
409
450
410
451
DEVICE inline bool operator <=(const half& a, const half& b) {
452
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
411
453
return __hle (a, b);
454
+ #else
455
+ return float (float16 (a)) <= float (float16 (b));
456
+ #endif
412
457
}
413
458
414
459
DEVICE inline bool operator >(const half& a, const half& b) {
460
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
415
461
return __hgt (a, b);
462
+ #else
463
+ return float (float16 (a)) > float (float16 (b));
464
+ #endif
416
465
}
417
466
418
467
DEVICE inline bool operator >=(const half& a, const half& b) {
468
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
419
469
return __hge (a, b);
470
+ #else
471
+ return float (float16 (a)) >= float (float16 (b));
472
+ #endif
420
473
}
421
474
422
- /*
423
- DEVICE inline float16 operator+(const float16& a, const float16& b) {
424
- return float16(__hadd(half(a), half(b)));
425
- }
426
-
427
- DEVICE inline float16 operator-(const float16& a, const float16& b) {
428
- return float16(__hsub(half(a), half(b)));
429
- }
430
-
431
- DEVICE inline float16 operator*(const float16& a, const float16& b) {
432
- return float16(__hmul(half(a), half(b)));
433
- }
434
-
435
- DEVICE inline float16 operator/(const float16& a, const float16& b) {
436
- float num = __half2float(half(a));
437
- float denom = __half2float(half(b));
438
- return float16(num / denom);
439
- }
440
-
441
- DEVICE inline float16 operator-(const float16& a) {
442
- return float16(__hneg(half(a)));
443
- }
444
-
445
- DEVICE inline float16& operator+=(float16& a, const float16& b) {
446
- a = a + b;
447
- return a;
448
- }
449
-
450
- DEVICE inline float16& operator-=(float16& a, const float16& b) {
451
- a = a - b;
452
- return a;
453
- }
454
-
455
- DEVICE inline float16& operator*=(float16& a, const float16& b) {
456
- a = a * b;
457
- return a;
458
- }
459
-
460
- DEVICE inline float16& operator/=(float16& a, const float16& b) {
461
- a = a / b;
462
- return a;
463
- }
464
-
465
- DEVICE inline bool operator==(const float16& a, const float16& b) {
466
- return __heq(half(a), half(b));
467
- }
468
-
469
- DEVICE inline bool operator!=(const float16& a, const float16& b) {
470
- return __hne(half(a), half(b));
471
- }
472
-
473
- DEVICE inline bool operator<(const float16& a, const float16& b) {
474
- return __hlt(half(a), half(b));
475
- }
476
-
477
- DEVICE inline bool operator<=(const float16& a, const float16& b) {
478
- return __hle(half(a), half(b));
479
- }
480
-
481
- DEVICE inline bool operator>(const float16& a, const float16& b) {
482
- return __hgt(half(a), half(b));
483
- }
484
-
485
- DEVICE inline bool operator>=(const float16& a, const float16& b) {
486
- return __hge(half(a), half(b));
487
- }
488
- */
475
+ #endif // PADDLE_CUDA_FP16
489
476
490
477
// Arithmetic operators on ARMv8.2-A CPU
491
- #elif defined(PADDLE_WITH_NATIVE_FP16)
478
+ #if defined(PADDLE_WITH_NATIVE_FP16)
492
479
HOST inline float16 operator +(const float16& a, const float16& b) {
493
480
float16 res;
494
481
asm volatile (
@@ -681,88 +668,6 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
681
668
return (res & 0xffff ) != 0 ;
682
669
}
683
670
684
- /*
685
- HOST inline float16 operator+(const float16& a, const float16& b) {
686
- return float16(vaddh_f16(float16_t(a), float16_t(b)));
687
- }
688
-
689
- HOST inline float16 operator-(const float16& a, const float16& b) {
690
- return float16(vsubh_f16(float16_t(a), float16_t(b)));
691
- }
692
-
693
- HOST inline float16 operator*(const float16& a, const float16& b) {
694
- return float16(vmulh_f16(float16_t(a), float16_t(b)));
695
- }
696
-
697
- HOST inline float16 operator/(const float16& a, const float16& b) {
698
- return float16(vdivh_f16(float16_t(a), float16_t(b)));
699
- }
700
-
701
- HOST inline float16 operator-(const float16& a) {
702
- return float16(vnegh_f16(float16_t(a)));
703
- }
704
-
705
- HOST inline float16& operator+=(float16& a, const float16& b) {
706
- a = a + b;
707
- return a;
708
- }
709
-
710
- HOST inline float16& operator-=(float16& a, const float16& b) {
711
- a = a - b;
712
- return a;
713
- }
714
-
715
- HOST inline float16& operator*=(float16& a, const float16& b) {
716
- a = a * b;
717
- return a;
718
- }
719
-
720
- HOST inline float16& operator/=(float16& a, const float16& b) {
721
- a = a / b;
722
- return a;
723
- }
724
-
725
- HOST inline bool operator==(const float16& a, const float16& b) {
726
- return static_cast<bool>(vceqh_f16(float16_t(a), float16_t(b)));
727
- }
728
-
729
- HOST inline bool operator!=(const float16& a, const float16& b) {
730
- return !(a == b);
731
- }
732
-
733
- HOST inline bool operator<(const float16& a, const float16& b) {
734
- #ifdef PADDLE_NEON_64
735
- return static_cast<bool>(vclth_f16(float16_t(a), float16_t(b)));
736
- #else
737
- return float(a) < float(b);
738
- #endif // PADDLE_NEON_64
739
- }
740
-
741
- HOST inline bool operator<=(const float16& a, const float16& b) {
742
- #ifdef PADDLE_NEON_64
743
- return static_cast<bool>(vcleh_f16(float16_t(a), float16_t(b)));
744
- #else
745
- return float(a) <= float(b);
746
- #endif // PADDLE_NEON_64
747
- }
748
-
749
- HOST inline bool operator>(const float16& a, const float16& b) {
750
- #ifdef PADDLE_NEON_64
751
- return static_cast<bool>(vcgth_f16(float16_t(a), float16_t(b)));
752
- #else
753
- return float(a) > float(b);
754
- #endif // PADDLE_NEON_64
755
- }
756
-
757
- HOST inline bool operator>=(const float16& a, const float16& b) {
758
- #ifdef PADDLE_NEON_64
759
- return static_cast<bool>(vcgeh_f16(float16_t(a), float16_t(b)));
760
- #else
761
- return float(a) >= float(b);
762
- #endif // PADDLE_NEON_64
763
- }
764
- */
765
-
766
671
// Arithmetic operators, software emulated on other CPU
767
672
#else
768
673
HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
0 commit comments