Skip to content

Commit fc4d03e

Browse files
committed
Support gfx*-generic targets
1 parent c1cd4a7 commit fc4d03e

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ namespace wmma = mtmusa::wmma;
1515
namespace wmma = nvcuda::wmma;
1616
#endif // GGML_USE_MUSA
1717
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
18+
// Workaround for gfx*-generic
19+
#if defined(__gfx11_generic__)
20+
#define __gfx1100__ __gfx11_generic__
21+
#elif defined(__gfx12_generic__)
22+
#define __gfx1201__ __gfx12_generic__
23+
#endif
1824
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
1925
#include <rocwmma/rocwmma.hpp>
2026
namespace wmma = rocwmma;

ggml/src/ggml-cuda/sum.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55
#ifdef USE_CUB
66

77
#if defined(GGML_USE_HIP)
8+
// Workaround for gfx*-generic
9+
#if defined(__gfx10_1_generic__)
10+
#define __gfx1010__ __gfx10_1_generic__
11+
#elif defined(__gfx10_3_generic__)
12+
#define __gfx1030__ __gfx10_3_generic__
13+
#elif defined(__gfx11_generic__)
14+
#define __gfx1100__ __gfx11_generic__
15+
#elif defined(__gfx12_generic__)
16+
#define __gfx1201__ __gfx12_generic__
17+
#endif
18+
819
#include <hipcub/hipcub.hpp>
920
using namespace hipcub;
1021
#else

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,17 @@
167167
#define RDNA4
168168
#endif
169169

170-
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
171-
defined(__gfx1150__) || defined(__gfx1151__)
170+
#if defined(__GFX11__)
172171
#define RDNA3
173172
#endif
174173

175174
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
176-
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
175+
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) || \
176+
defined(__gfx10_3_generic__)
177177
#define RDNA2
178178
#endif
179179

180-
#if defined(__gfx1010__) || defined(__gfx1012__)
180+
#if defined(__gfx1010__) || defined(__gfx1012__) || defined(__gfx10_1_generic__)
181181
#define RDNA1
182182
#endif
183183

0 commit comments

Comments
 (0)