Skip to content

Commit 6c5c8be

Browse files
committed
Try to make ROCm work for the GitHub CI; this requires disabling rocWMMA.
1 parent 7f57846 commit 6c5c8be

File tree

4 files changed: 14 additions and 5 deletions

.github/workflows/kcpp-build-release-linux-rocm.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ env:
1212
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
1313
KCPP_CUDA: rocm
1414
ARCHES_CU12: 1
15+
NO_WMMA: 1
1516

1617
jobs:
1718
linux:

Makefile

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -244,21 +244,25 @@ ifdef LLAMA_HIPBLAS
244244
ifeq ($(wildcard /opt/rocm),)
245245
ROCM_PATH ?= /usr
246246
ifdef LLAMA_PORTABLE
247-
GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 $(shell $(shell which amdgpu-arch))
247+
GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 gfx1200 gfx1201 $(shell $(shell which amdgpu-arch))
248248
else
249249
GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
250250
endif
251251
HCC := $(ROCM_PATH)/bin/hipcc
252252
HCXX := $(ROCM_PATH)/bin/hipcc
253253
else
254254
ROCM_PATH ?= /opt/rocm
255-
GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
255+
GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 gfx1200 gfx1201 $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
256256
HCC := $(ROCM_PATH)/llvm/bin/clang
257257
HCXX := $(ROCM_PATH)/llvm/bin/clang++
258258
endif
259+
ifdef LLAMA_NO_WMMA
260+
HIPFLAGS += -DGGML_HIP_NO_ROCWMMA_FATTN
261+
else
259262
DETECT_ROCWMMA := $(shell find -L /opt/rocm/include /usr/include -type f -name rocwmma.hpp 2>/dev/null | head -n 1)
260263
ifdef DETECT_ROCWMMA
261264
HIPFLAGS += -DGGML_HIP_ROCWMMA_FATTN -I$(dir $(DETECT_ROCWMMA))
265+
endif
262266
endif
263267

264268
HIPFLAGS += -DGGML_USE_HIP -DGGML_HIP_NO_VMM -DGGML_USE_CUDA -DSD_USE_CUDA $(shell $(ROCM_PATH)/bin/hipconfig -C)

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
5656
const int ne1,
5757
const int ne2,
5858
const int ne3) {
59-
#if defined(FLASH_ATTN_AVAILABLE) && ((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ == GGML_CUDA_CC_TURING) || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
59+
#if !defined(GGML_HIP_NO_ROCWMMA_FATTN) && defined(FLASH_ATTN_AVAILABLE) && ((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ == GGML_CUDA_CC_TURING) || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
6060
// Skip unused kernel variants for faster compilation:
6161
if (use_logit_softcap && !(D == 128 || D == 256)) {
6262
NO_DEVICE_CODE;

koboldcpp.sh

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ KCPP_CUDAAPPEND=-cuda${KCPP_CUDA//.}$KCPP_APPEND
2929

3030
LLAMA_NOAVX2_FLAG=""
3131
ARCHES_FLAG=""
32+
NO_WMMA_FLAG=""
3233
if [ -n "$NOAVX2" ]; then
3334
LLAMA_NOAVX2_FLAG="LLAMA_NOAVX2=1"
3435
fi
@@ -38,11 +39,14 @@ fi
3839
if [ -n "$ARCHES_CU12" ]; then
3940
ARCHES_FLAG="LLAMA_ARCHES_CU12=1"
4041
fi
42+
if [ -n "$NO_WMMA" ]; then
43+
NO_WMMA_FLAG="LLAMA_NO_WMMA=1"
44+
fi
4145

4246
if [ "$KCPP_CUDA" = "rocm" ]; then
43-
bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_HIPBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG
47+
bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_HIPBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG $NO_WMMA_FLAG
4448
else
45-
bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG
49+
bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG $NO_WMMA_FLAG
4650
fi
4751

4852
if [ $? -ne 0 ]; then

0 commit comments

Comments
 (0)