Skip to content

Commit 1af9eb1

Browse files
authored
[CINN] [New Hardware Update]:Add CINN HIP fp16 support (#74361)
* add cinn hip fp16 support
* fix bugs
* Update hip_intrinsics_float16.cc
1 parent 1e11628 commit 1af9eb1

File tree

6 files changed

+336
-37
lines changed

6 files changed

+336
-37
lines changed

paddle/cinn/backends/hip/codegen_hip_dev.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ namespace backends {
1919
namespace hip {
2020

2121
const std::string CodeGenHipDevice::source_header_ = // NOLINT
22-
R"(#include "cinn_hip_runtime_source.h"
22+
R"(#define CINN_WITH_HIP
23+
#include "float16.h"
24+
using cinn::common::float16;
25+
#include "cinn_hip_runtime_source.h"
2326
)";
2427

2528
const std::string &CodeGenHipDevice::GetSourceHeader() {

paddle/cinn/common/float16.h

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@
4040
#endif // __CUDACC__
4141
#endif // CINN_WITH_CUDA
4242

43+
#ifdef CINN_WITH_HIP
44+
#include <hip/hip_runtime.h>
45+
#if defined(__HIPCC__)
46+
#define __HIP_PLATFORM_AMD__
47+
#include <hip/hip_fp16.h>
48+
#define CINN_HIP_FP16
49+
#endif
50+
#endif
51+
4352
#ifdef __cplusplus
4453
#ifndef _WIN32
4554
#define CINN_ALIGN(x) __attribute__((aligned(x)))
@@ -83,7 +92,7 @@ struct CINN_ALIGN(2) float16 {
8392
~float16() = default;
8493

8594
// Constructors
86-
#ifdef CINN_CUDA_FP16
95+
#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
8796
__host__ __device__ inline explicit float16(const half& h) {
8897
#if (CUDA_VERSION >= 9000)
8998
x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
@@ -129,9 +138,9 @@ struct CINN_ALIGN(2) float16 {
129138
: x(float16(static_cast<float>(val)).x) {}
130139

131140
// Assignment operators
132-
#ifdef CINN_CUDA_FP16
141+
#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
133142
__host__ __device__ inline float16& operator=(const half& rhs) {
134-
#if CUDA_VERSION >= 9000
143+
#if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
135144
x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
136145
#else
137146
x = rhs.x;
@@ -196,9 +205,9 @@ struct CINN_ALIGN(2) float16 {
196205
}
197206

198207
// Conversion operators
199-
#ifdef CINN_CUDA_FP16
208+
#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
200209
__host__ __device__ inline half to_half() const {
201-
#if CUDA_VERSION >= 9000
210+
#if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
202211
__half_raw h;
203212
h.x = x;
204213
return half(h);
@@ -211,7 +220,9 @@ struct CINN_ALIGN(2) float16 {
211220
#endif // CINN_CUDA_FP16 || CINN_HIP_FP16
212221

213222
__host__ __device__ inline operator float() const {
214-
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
223+
#if defined(CINN_CUDA_FP16) && \
224+
(defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \
225+
defined(CINN_HIP_FP16)
215226
half tmp = *reinterpret_cast<const half*>(this);
216227
return __half2float(tmp);
217228

@@ -344,9 +355,9 @@ struct CINN_ALIGN(4) float162 {
344355
// CUDA 9.0 regarding the half data type.
345356
// ROCm has built-in half arithmetic operators as long as
346357
// __HIP_NO_HALF_OPERATORS__ is not defined.
347-
#if defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000
358+
#if (defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000) || defined(CINN_HIP_FP16)
348359
__device__ inline half operator+(const half& a, const half& b) {
349-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
360+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
350361
return __hadd(a, b);
351362
#else
352363
float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
@@ -355,7 +366,7 @@ __device__ inline half operator+(const half& a, const half& b) {
355366
}
356367

357368
__device__ inline half operator-(const half& a, const half& b) {
358-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
369+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
359370
return __hsub(a, b);
360371
#else
361372
float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
@@ -364,7 +375,7 @@ __device__ inline half operator-(const half& a, const half& b) {
364375
}
365376

366377
__device__ inline half operator*(const half& a, const half& b) {
367-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
378+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
368379
return __hmul(a, b);
369380
#else
370381
float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
@@ -373,7 +384,7 @@ __device__ inline half operator*(const half& a, const half& b) {
373384
}
374385

375386
__device__ inline half operator/(const half& a, const half& b) {
376-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
387+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
377388
float num = __half2float(a);
378389
float denom = __half2float(b);
379390
return __float2half(num / denom);
@@ -384,14 +395,15 @@ __device__ inline half operator/(const half& a, const half& b) {
384395
}
385396

386397
__device__ inline half operator-(const half& a) {
387-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
398+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
388399
return __hneg(a);
389400
#else
390401
float res = -static_cast<float>(float16(a));
391402
return float16(res).to_half();
392403
#endif
393404
}
394405

406+
#ifndef CINN_WITH_HIP
395407
__device__ inline half& operator+=(half& a, const half& b) { // NOLINT
396408
a = a + b;
397409
return a;
@@ -411,49 +423,50 @@ __device__ inline half& operator/=(half& a, const half& b) { // NOLINT
411423
a = a / b;
412424
return a;
413425
}
426+
#endif
414427

415428
__device__ inline bool operator==(const half& a, const half& b) {
416-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
429+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
417430
return __heq(a, b);
418431
#else
419432
return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
420433
#endif
421434
}
422435

423436
__device__ inline bool operator!=(const half& a, const half& b) {
424-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
437+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
425438
return __hne(a, b);
426439
#else
427440
return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
428441
#endif
429442
}
430443

431444
__device__ inline bool operator<(const half& a, const half& b) {
432-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
445+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
433446
return __hlt(a, b);
434447
#else
435448
return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
436449
#endif
437450
}
438451

439452
__device__ inline bool operator<=(const half& a, const half& b) {
440-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
453+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
441454
return __hle(a, b);
442455
#else
443456
return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
444457
#endif
445458
}
446459

447460
__device__ inline bool operator>(const half& a, const half& b) {
448-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
461+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
449462
return __hgt(a, b);
450463
#else
451464
return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
452465
#endif
453466
}
454467

455468
__device__ inline bool operator>=(const half& a, const half& b) {
456-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
469+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
457470
return __hge(a, b);
458471
#else
459472
return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
@@ -465,7 +478,9 @@ __device__ inline bool operator>=(const half& a, const half& b) {
465478
// Arithmetic operators for float16 on GPU
466479
__host__ __device__ inline float16 operator+(const float16& a,
467480
const float16& b) {
468-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
481+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
482+
__CUDA_ARCH__ >= 530) || \
483+
defined(CINN_HIP_FP16)
469484
return float16(__hadd(a.to_half(), b.to_half()));
470485
#else
471486
return float16(static_cast<float>(a) + static_cast<float>(b));
@@ -474,7 +489,9 @@ __host__ __device__ inline float16 operator+(const float16& a,
474489

475490
__host__ __device__ inline float16 operator-(const float16& a,
476491
const float16& b) {
477-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
492+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
493+
__CUDA_ARCH__ >= 530) || \
494+
defined(CINN_HIP_FP16)
478495
return float16(__hsub(a.to_half(), b.to_half()));
479496
#else
480497
return float16(static_cast<float>(a) - static_cast<float>(b));
@@ -483,7 +500,9 @@ __host__ __device__ inline float16 operator-(const float16& a,
483500

484501
__host__ __device__ inline float16 operator*(const float16& a,
485502
const float16& b) {
486-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
503+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
504+
__CUDA_ARCH__ >= 530) || \
505+
defined(CINN_HIP_FP16)
487506
return float16(__hmul(a.to_half(), b.to_half()));
488507
#else
489508
return float16(static_cast<float>(a) * static_cast<float>(b));
@@ -492,7 +511,9 @@ __host__ __device__ inline float16 operator*(const float16& a,
492511

493512
__host__ __device__ inline float16 operator/(const float16& a,
494513
const float16& b) {
495-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
514+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
515+
__CUDA_ARCH__ >= 530) || \
516+
defined(CINN_HIP_FP16)
496517
// TODO(kexinzhao): check which cuda version starts to support __hdiv
497518
float num = __half2float(a.to_half());
498519
float denom = __half2float(b.to_half());
@@ -503,7 +524,9 @@ __host__ __device__ inline float16 operator/(const float16& a,
503524
}
504525

505526
__host__ __device__ inline float16 operator-(const float16& a) {
506-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
527+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
528+
__CUDA_ARCH__ >= 530) || \
529+
defined(CINN_HIP_FP16)
507530
return float16(__hneg(a.to_half()));
508531
#else
509532
float16 res;
@@ -537,47 +560,59 @@ __host__ __device__ inline float16& operator/=(float16& a, // NOLINT
537560
}
538561

539562
__host__ __device__ inline bool operator==(const float16& a, const float16& b) {
540-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
563+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
564+
__CUDA_ARCH__ >= 530) || \
565+
defined(CINN_HIP_FP16)
541566
return __heq(a.to_half(), b.to_half());
542567
#else
543568
return static_cast<float>(a) == static_cast<float>(b);
544569
#endif
545570
}
546571

547572
__host__ __device__ inline bool operator!=(const float16& a, const float16& b) {
548-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
573+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
574+
__CUDA_ARCH__ >= 530) || \
575+
defined(CINN_HIP_FP16)
549576
return __hne(a.to_half(), b.to_half());
550577
#else
551578
return static_cast<float>(a) != static_cast<float>(b);
552579
#endif
553580
}
554581

555582
__host__ __device__ inline bool operator<(const float16& a, const float16& b) {
556-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
583+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
584+
__CUDA_ARCH__ >= 530) || \
585+
defined(CINN_HIP_FP16)
557586
return __hlt(a.to_half(), b.to_half());
558587
#else
559588
return static_cast<float>(a) < static_cast<float>(b);
560589
#endif
561590
}
562591

563592
__host__ __device__ inline bool operator<=(const float16& a, const float16& b) {
564-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
593+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
594+
__CUDA_ARCH__ >= 530) || \
595+
defined(CINN_HIP_FP16)
565596
return __hle(a.to_half(), b.to_half());
566597
#else
567598
return static_cast<float>(a) <= static_cast<float>(b);
568599
#endif
569600
}
570601

571602
__host__ __device__ inline bool operator>(const float16& a, const float16& b) {
572-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
603+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
604+
__CUDA_ARCH__ >= 530) || \
605+
defined(CINN_HIP_FP16)
573606
return __hgt(a.to_half(), b.to_half());
574607
#else
575608
return static_cast<float>(a) > static_cast<float>(b);
576609
#endif
577610
}
578611

579612
__host__ __device__ inline bool operator>=(const float16& a, const float16& b) {
580-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
613+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
614+
__CUDA_ARCH__ >= 530) || \
615+
defined(CINN_HIP_FP16)
581616
return __hge(a.to_half(), b.to_half());
582617
#else
583618
return static_cast<float>(a) >= static_cast<float>(b);
@@ -592,7 +627,9 @@ __host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) {
592627
}
593628

594629
__host__ __device__ inline bool(isnan)(const float16& a) {
595-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
630+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
631+
__CUDA_ARCH__ >= 530) || \
632+
defined(CINN_HIP_FP16)
596633
return __hisnan(a.to_half());
597634
#else
598635
return (a.x & 0x7fff) > 0x7c00;
@@ -608,7 +645,9 @@ __host__ __device__ inline bool(isfinite)(const float16& a) {
608645
}
609646

610647
__host__ __device__ inline float16(abs)(const float16& a) {
611-
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
648+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
649+
__CUDA_ARCH__ >= 530) || \
650+
defined(CINN_HIP_FP16)
612651
return static_cast<float16>(__habs(a.to_half()));
613652
#else
614653
return static_cast<float16>(fabsf(static_cast<float>(a)));

paddle/cinn/runtime/hip/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ gather_srcs(
77
hip_backend_api.cc
88
hip_module.cc
99
hip_intrinsics.cc
10-
hip_intrinsics_reduce.cc)
10+
hip_intrinsics_reduce.cc
11+
hip_intrinsics_float16.cc)

0 commit comments

Comments
 (0)