@@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND)
2828 list (APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h" )
2929
3030 file (GLOB GGML_SOURCES_CUDA "*.cu" )
31- file (GLOB SRCS "template-instances/fattn-wmma*.cu" )
32- list (APPEND GGML_SOURCES_CUDA ${SRCS} )
33- file (GLOB SRCS "template-instances/mmq*.cu" )
34- list (APPEND GGML_SOURCES_CUDA ${SRCS} )
35-
36- if (GGML_CUDA_FA_ALL_QUANTS)
37- file (GLOB SRCS "template-instances/fattn-vec*.cu" )
31+ if (GGML_CUDA_FA)
32+ file (GLOB SRCS "template-instances/fattn-wmma*.cu" )
3833 list (APPEND GGML_SOURCES_CUDA ${SRCS} )
39- add_compile_definitions (GGML_CUDA_FA_ALL_QUANTS)
4034 else ()
41- file (GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu" )
42- list (APPEND GGML_SOURCES_CUDA ${SRCS} )
43- file (GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu" )
44- list (APPEND GGML_SOURCES_CUDA ${SRCS} )
45- file (GLOB SRCS "template-instances/fattn-vec*f16-f16.cu" )
35+ list (FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*" )
36+ list (FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*" )
37+ # message(FATAL_ERROR ${GGML_SOURCES_CUDA})
38+ endif ()
39+ if (NOT GGML_CUDA_FORCE_CUBLAS)
40+ file (GLOB SRCS "template-instances/mmq*.cu" )
4641 list (APPEND GGML_SOURCES_CUDA ${SRCS} )
4742 endif ()
4843
44+ if (GGML_CUDA_FA)
45+ add_compile_definitions (GGML_CUDA_FA)
46+ if (GGML_CUDA_FA_ALL_QUANTS)
47+ file (GLOB SRCS "template-instances/fattn-vec*.cu" )
48+ list (APPEND GGML_SOURCES_CUDA ${SRCS} )
49+ add_compile_definitions (GGML_CUDA_FA_ALL_QUANTS)
50+ else ()
51+ file (GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu" )
52+ list (APPEND GGML_SOURCES_CUDA ${SRCS} )
53+ file (GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu" )
54+ list (APPEND GGML_SOURCES_CUDA ${SRCS} )
55+ file (GLOB SRCS "template-instances/fattn-vec*f16-f16.cu" )
56+ list (APPEND GGML_SOURCES_CUDA ${SRCS} )
57+ endif ()
58+ endif ()
59+
4960 ggml_add_backend_library(ggml-cuda
5061 ${GGML_HEADERS_CUDA}
5162 ${GGML_SOURCES_CUDA}
0 commit comments