Skip to content

Commit efd21bf

Browse files
author
Huaishun Hu
committed
musa: fix __dp4a incorrect result
1 parent 6afb592 commit efd21bf

File tree

5 files changed

+31
-24
lines changed

5 files changed

+31
-24
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ ifdef GGML_MUSA
847847
CXX := $(MUSA_PATH)/bin/clang++
848848
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc
849849

850-
MUSAFLAGS = -x musa -mtgpu
850+
MUSAFLAGS = -fsigned-char -x musa -mtgpu
851851
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))
852852

853853
ifdef GGML_CUDA_FORCE_MMQ

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
353353

354354
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
355355

356-
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
356+
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
357357
return __dp4a(a, b, c);
358358
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
359359
const int8_t * a8 = (const int8_t *) &a;

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
136136
return true;
137137
}
138138

139+
#if defined(GGML_USE_MUSA)
140+
return true;
141+
#endif
142+
139143
if (cc < GGML_CUDA_CC_DP4A) {
140144
return false;
141145
}

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -137,32 +137,35 @@
137137
typedef mt_bfloat16 nv_bfloat16;
138138

139139
/** FIXME: MUSA arch should match CUDA 11.4 */
140-
// #define CC_OFFSET_MT 99999 // should < CC_OFFSET_AMD
141-
// #define __CUDA_ARCH__ CC_OFFSET_MT
142-
// #define __CUDA_ARCH__ 800 // AMPERE
140+
// #define GGML_CUDA_CC_PASCAL 600
141+
// #define GGML_CUDA_CC_DP4A 610
142+
// #define GGML_CUDA_CC_VOLTA 700
143+
// #define GGML_CUDA_CC_TURING 750
144+
// #define GGML_CUDA_CC_AMPERE 800
143145

144-
#define __MUSA_CC__ 800
146+
#define __MUSA_CC__ 610
147+
// #define __CUDA_ARCH__ __MUSA_CC__
145148

146149

147-
/** TODO: following apis not supported yet by musa sdk: ***********
150+
/** TODO: following apis not supported yet by musa sdk: ***********/
148151

149-
__device__ __half hexp(const __half a) {
150-
float f_a = __half2float(a);
151-
float f_result = expf(f_a);
152-
return __float2half(f_result);
153-
}
152+
// __device__ __half hexp(const __half a) {
153+
// float f_a = __half2float(a);
154+
// float f_result = expf(f_a);
155+
// return __float2half(f_result);
156+
// }
154157

155-
__host__ __device__ __half2 h2exp(const __half2 a) {
156-
// Extract lower and upper halves
157-
__half lower = __low2half(a);
158-
__half upper = __high2half(a);
158+
// __host__ __device__ __half2 h2exp(const __half2 a) {
159+
// // Extract lower and upper halves
160+
// __half lower = __low2half(a);
161+
// __half upper = __high2half(a);
159162

160-
// Compute exp for each half
161-
__half exp_lower = hexp(lower);
162-
__half exp_upper = hexp(upper);
163+
// // Compute exp for each half
164+
// __half exp_lower = hexp(lower);
165+
// __half exp_upper = hexp(upper);
163166

164-
// Combine back into __half2
165-
return __halves2half2(exp_lower, exp_upper);
166-
}
167+
// // Combine back into __half2
168+
// return __halves2half2(exp_lower, exp_upper);
169+
// }
167170

168-
******************************************************************/
171+
/******************************************************************/

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)
4949

5050
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
5151
foreach(SOURCE ${GGML_SOURCES_MUSA})
52-
set(COMPILE_FLAGS "-x musa -mtgpu")
52+
set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
5353
foreach(ARCH ${MUSA_ARCHITECTURES})
5454
set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
5555
endforeach()

0 commit comments

Comments
 (0)