Skip to content

Commit ecb468d

Browse files
author
hzhang13
committed
fix compiler error
1 parent 7713abe commit ecb468d

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

ggml/src/ggml-cuda/mma.cuh

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,30 @@ namespace ggml_cuda_mma {
7070
static constexpr int J = J_;
7171

7272
#if defined(GGML_USE_HIP)
73-
#if defined(CDNA)
73+
#if defined(RDNA4)
74+
static constexpr int ne = I * J / 32;
75+
T x[ne] = {0};
76+
77+
static __device__ __forceinline__ int get_i(const int l) {
78+
if constexpr (I == 16 && J == 16) {
79+
return 8 * (threadIdx.x / 16) + l;
80+
} else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
81+
return 4 * (threadIdx.x / 16) + l;
82+
} else {
83+
static_assert(I == -1 && J == -1, "template specialization not implemented");
84+
}
85+
}
86+
87+
static __device__ __forceinline__ int get_j(const int l) {
88+
if constexpr (I == 16 && J == 16) {
89+
return threadIdx.x % 16;
90+
} else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
91+
return threadIdx.x % 16;
92+
} else {
93+
static_assert(I == -1 && J == -1, "template specialization not implemented");
94+
}
95+
}
96+
#else
7497
static constexpr int ne = I * J / 64;
7598
T x[ne] = {0};
7699

@@ -105,30 +128,7 @@ namespace ggml_cuda_mma {
105128
static_assert(I == -1 && J == -1, "template specialization not implemented");
106129
}
107130
}
108-
#elif defined(RDNA4)
109-
static constexpr int ne = I * J / 32;
110-
T x[ne] = {0};
111-
112-
static __device__ __forceinline__ int get_i(const int l) {
113-
if constexpr (I == 16 && J == 16) {
114-
return 8 * (threadIdx.x / 16) + l;
115-
} else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
116-
return 4 * (threadIdx.x / 16) + l;
117-
} else {
118-
static_assert(I == -1 && J == -1, "template specialization not implemented");
119-
}
120-
}
121-
122-
static __device__ __forceinline__ int get_j(const int l) {
123-
if constexpr (I == 16 && J == 16) {
124-
return threadIdx.x % 16;
125-
} else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
126-
return threadIdx.x % 16;
127-
} else {
128-
static_assert(I == -1 && J == -1, "template specialization not implemented");
129-
}
130-
}
131-
#endif // defined(CDNA)
131+
#endif // defined(RDNA4)
132132
#else
133133
static constexpr int ne = I * J / 32;
134134
T x[ne] = {0};

0 commit comments

Comments
 (0)