40
40
#endif // __CUDACC__
41
41
#endif // CINN_WITH_CUDA
42
42
43
+ #ifdef CINN_WITH_HIP
44
+ #include < hip/hip_runtime.h>
45
+ #if defined(__HIPCC__)
46
+ #define __HIP_PLATFORM_AMD__
47
+ #include < hip/hip_fp16.h>
48
+ #define CINN_HIP_FP16
49
+ #endif
50
+ #endif
51
+
43
52
#ifdef __cplusplus
44
53
#ifndef _WIN32
45
54
#define CINN_ALIGN (x ) __attribute__((aligned(x)))
@@ -83,7 +92,7 @@ struct CINN_ALIGN(2) float16 {
83
92
~float16 () = default ;
84
93
85
94
// Constructors
86
- #ifdef CINN_CUDA_FP16
95
+ #if defined( CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
87
96
__host__ __device__ inline explicit float16 (const half& h) {
88
97
#if (CUDA_VERSION >= 9000)
89
98
x = reinterpret_cast <__half_raw*>(const_cast <half*>(&h))->x ;
@@ -129,9 +138,9 @@ struct CINN_ALIGN(2) float16 {
129
138
: x (float16 (static_cast <float >(val)).x ) {}
130
139
131
140
// Assignment operators
132
- #ifdef CINN_CUDA_FP16
141
+ #if defined( CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
133
142
__host__ __device__ inline float16& operator =(const half& rhs) {
134
- #if CUDA_VERSION >= 9000
143
+ #if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
135
144
x = reinterpret_cast <__half_raw*>(const_cast <half*>(&rhs))->x ;
136
145
#else
137
146
x = rhs.x ;
@@ -196,9 +205,9 @@ struct CINN_ALIGN(2) float16 {
196
205
}
197
206
198
207
// Conversion operators
199
- #ifdef CINN_CUDA_FP16
208
+ #if defined( CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
200
209
__host__ __device__ inline half to_half () const {
201
- #if CUDA_VERSION >= 9000
210
+ #if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
202
211
__half_raw h;
203
212
h.x = x;
204
213
return half (h);
@@ -211,7 +220,9 @@ struct CINN_ALIGN(2) float16 {
211
220
#endif // CINN_CUDA_FP16
212
221
213
222
__host__ __device__ inline operator float () const {
214
- #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
223
+ #if defined(CINN_CUDA_FP16) && \
224
+ (defined (__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 ) || \
225
+ defined (CINN_HIP_FP16)
215
226
half tmp = *reinterpret_cast <const half*>(this );
216
227
return __half2float (tmp);
217
228
@@ -344,9 +355,9 @@ struct CINN_ALIGN(4) float162 {
344
355
// CUDA 9.0 regarding the half data type.
345
356
// ROCM has built-in arithmetic operators as not defined
346
357
// __HIP_NO_HALF_OPERATORS__
347
- #if defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000
358
+ #if ( defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000) || defined(CINN_HIP_FP16)
348
359
__device__ inline half operator +(const half& a, const half& b) {
349
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
360
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
350
361
return __hadd (a, b);
351
362
#else
352
363
float res = static_cast <float >(float16 (a)) + static_cast <float >(float16 (b));
@@ -355,7 +366,7 @@ __device__ inline half operator+(const half& a, const half& b) {
355
366
}
356
367
357
368
__device__ inline half operator -(const half& a, const half& b) {
358
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
369
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
359
370
return __hsub (a, b);
360
371
#else
361
372
float res = static_cast <float >(float16 (a)) - static_cast <float >(float16 (b));
@@ -364,7 +375,7 @@ __device__ inline half operator-(const half& a, const half& b) {
364
375
}
365
376
366
377
__device__ inline half operator *(const half& a, const half& b) {
367
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
378
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
368
379
return __hmul (a, b);
369
380
#else
370
381
float res = static_cast <float >(float16 (a)) * static_cast <float >(float16 (b));
@@ -373,7 +384,7 @@ __device__ inline half operator*(const half& a, const half& b) {
373
384
}
374
385
375
386
__device__ inline half operator /(const half& a, const half& b) {
376
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
387
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
377
388
float num = __half2float (a);
378
389
float denom = __half2float (b);
379
390
return __float2half (num / denom);
@@ -384,14 +395,15 @@ __device__ inline half operator/(const half& a, const half& b) {
384
395
}
385
396
386
397
__device__ inline half operator -(const half& a) {
387
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
398
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
388
399
return __hneg (a);
389
400
#else
390
401
float res = -static_cast <float >(float16 (a));
391
402
return float16 (res).to_half ();
392
403
#endif
393
404
}
394
405
406
+ #ifndef CINN_WITH_HIP
395
407
__device__ inline half& operator +=(half& a, const half& b) { // NOLINT
396
408
a = a + b;
397
409
return a;
@@ -411,49 +423,50 @@ __device__ inline half& operator/=(half& a, const half& b) { // NOLINT
411
423
a = a / b;
412
424
return a;
413
425
}
426
+ #endif
414
427
415
428
__device__ inline bool operator ==(const half& a, const half& b) {
416
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
429
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
417
430
return __heq (a, b);
418
431
#else
419
432
return static_cast <float >(float16 (a)) == static_cast <float >(float16 (b));
420
433
#endif
421
434
}
422
435
423
436
__device__ inline bool operator !=(const half& a, const half& b) {
424
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
437
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
425
438
return __hne (a, b);
426
439
#else
427
440
return static_cast <float >(float16 (a)) != static_cast <float >(float16 (b));
428
441
#endif
429
442
}
430
443
431
444
__device__ inline bool operator <(const half& a, const half& b) {
432
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
445
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
433
446
return __hlt (a, b);
434
447
#else
435
448
return static_cast <float >(float16 (a)) < static_cast <float >(float16 (b));
436
449
#endif
437
450
}
438
451
439
452
__device__ inline bool operator <=(const half& a, const half& b) {
440
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
453
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
441
454
return __hle (a, b);
442
455
#else
443
456
return static_cast <float >(float16 (a)) <= static_cast <float >(float16 (b));
444
457
#endif
445
458
}
446
459
447
460
__device__ inline bool operator >(const half& a, const half& b) {
448
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
461
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
449
462
return __hgt (a, b);
450
463
#else
451
464
return static_cast <float >(float16 (a)) > static_cast <float >(float16 (b));
452
465
#endif
453
466
}
454
467
455
468
__device__ inline bool operator >=(const half& a, const half& b) {
456
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
469
+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
457
470
return __hge (a, b);
458
471
#else
459
472
return static_cast <float >(float16 (a)) >= static_cast <float >(float16 (b));
@@ -465,7 +478,9 @@ __device__ inline bool operator>=(const half& a, const half& b) {
465
478
// Arithmetic operators for float16 on GPU
466
479
__host__ __device__ inline float16 operator +(const float16& a,
467
480
const float16& b) {
468
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
481
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
482
+ __CUDA_ARCH__ >= 530 ) || \
483
+ defined (CINN_HIP_FP16)
469
484
return float16 (__hadd (a.to_half (), b.to_half ()));
470
485
#else
471
486
return float16 (static_cast <float >(a) + static_cast <float >(b));
@@ -474,7 +489,9 @@ __host__ __device__ inline float16 operator+(const float16& a,
474
489
475
490
__host__ __device__ inline float16 operator -(const float16& a,
476
491
const float16& b) {
477
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
492
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
493
+ __CUDA_ARCH__ >= 530 ) || \
494
+ defined (CINN_HIP_FP16)
478
495
return float16 (__hsub (a.to_half (), b.to_half ()));
479
496
#else
480
497
return float16 (static_cast <float >(a) - static_cast <float >(b));
@@ -483,7 +500,9 @@ __host__ __device__ inline float16 operator-(const float16& a,
483
500
484
501
__host__ __device__ inline float16 operator *(const float16& a,
485
502
const float16& b) {
486
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
503
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
504
+ __CUDA_ARCH__ >= 530 ) || \
505
+ defined (CINN_HIP_FP16)
487
506
return float16 (__hmul (a.to_half (), b.to_half ()));
488
507
#else
489
508
return float16 (static_cast <float >(a) * static_cast <float >(b));
@@ -492,7 +511,9 @@ __host__ __device__ inline float16 operator*(const float16& a,
492
511
493
512
__host__ __device__ inline float16 operator /(const float16& a,
494
513
const float16& b) {
495
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
514
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
515
+ __CUDA_ARCH__ >= 530 ) || \
516
+ defined (CINN_HIP_FP16)
496
517
// TODO(kexinzhao): check which cuda version starts to support __hdiv
497
518
float num = __half2float (a.to_half ());
498
519
float denom = __half2float (b.to_half ());
@@ -503,7 +524,9 @@ __host__ __device__ inline float16 operator/(const float16& a,
503
524
}
504
525
505
526
__host__ __device__ inline float16 operator -(const float16& a) {
506
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
527
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
528
+ __CUDA_ARCH__ >= 530 ) || \
529
+ defined (CINN_HIP_FP16)
507
530
return float16 (__hneg (a.to_half ()));
508
531
#else
509
532
float16 res;
@@ -537,47 +560,59 @@ __host__ __device__ inline float16& operator/=(float16& a, // NOLINT
537
560
}
538
561
539
562
__host__ __device__ inline bool operator ==(const float16& a, const float16& b) {
540
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
563
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
564
+ __CUDA_ARCH__ >= 530 ) || \
565
+ defined (CINN_HIP_FP16)
541
566
return __heq (a.to_half (), b.to_half ());
542
567
#else
543
568
return static_cast <float >(a) == static_cast <float >(b);
544
569
#endif
545
570
}
546
571
547
572
__host__ __device__ inline bool operator !=(const float16& a, const float16& b) {
548
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
573
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
574
+ __CUDA_ARCH__ >= 530 ) || \
575
+ defined (CINN_HIP_FP16)
549
576
return __hne (a.to_half (), b.to_half ());
550
577
#else
551
578
return static_cast <float >(a) != static_cast <float >(b);
552
579
#endif
553
580
}
554
581
555
582
__host__ __device__ inline bool operator <(const float16& a, const float16& b) {
556
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
583
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
584
+ __CUDA_ARCH__ >= 530 ) || \
585
+ defined (CINN_HIP_FP16)
557
586
return __hlt (a.to_half (), b.to_half ());
558
587
#else
559
588
return static_cast <float >(a) < static_cast <float >(b);
560
589
#endif
561
590
}
562
591
563
592
__host__ __device__ inline bool operator <=(const float16& a, const float16& b) {
564
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
593
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
594
+ __CUDA_ARCH__ >= 530 ) || \
595
+ defined (CINN_HIP_FP16)
565
596
return __hle (a.to_half (), b.to_half ());
566
597
#else
567
598
return static_cast <float >(a) <= static_cast <float >(b);
568
599
#endif
569
600
}
570
601
571
602
__host__ __device__ inline bool operator >(const float16& a, const float16& b) {
572
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
603
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
604
+ __CUDA_ARCH__ >= 530 ) || \
605
+ defined (CINN_HIP_FP16)
573
606
return __hgt (a.to_half (), b.to_half ());
574
607
#else
575
608
return static_cast <float >(a) > static_cast <float >(b);
576
609
#endif
577
610
}
578
611
579
612
__host__ __device__ inline bool operator >=(const float16& a, const float16& b) {
580
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
613
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
614
+ __CUDA_ARCH__ >= 530 ) || \
615
+ defined (CINN_HIP_FP16)
581
616
return __hge (a.to_half (), b.to_half ());
582
617
#else
583
618
return static_cast <float >(a) >= static_cast <float >(b);
@@ -592,7 +627,9 @@ __host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) {
592
627
}
593
628
594
629
__host__ __device__ inline bool (isnan)(const float16& a) {
595
- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
630
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
631
+ __CUDA_ARCH__ >= 530 ) || \
632
+ defined (CINN_HIP_FP16)
596
633
return __hisnan (a.to_half ());
597
634
#else
598
635
return (a.x & 0x7fff ) > 0x7c00 ;
@@ -608,7 +645,9 @@ __host__ __device__ inline bool(isfinite)(const float16& a) {
608
645
}
609
646
610
647
__host__ __device__ inline float16 (abs)(const float16& a) {
611
- #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
648
+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
649
+ __CUDA_ARCH__ >= 530 ) || \
650
+ defined (CINN_HIP_FP16)
612
651
return static_cast <float16>(__habs (a.to_half ()));
613
652
#else
614
653
return static_cast <float16>(fabsf (static_cast <float >(a)));
0 commit comments