Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added ggml/src/ggml-vulkan/.CMakeLists.txt.swp
Binary file not shown.
13 changes: 13 additions & 0 deletions ggml/src/ggml-vulkan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ if (Vulkan_FOUND)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
endif()

# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
else()
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
endif()

target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

Expand Down
15 changes: 14 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#undef CREATE_MM2
} else
#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat_support) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
Expand Down Expand Up @@ -1739,7 +1740,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#undef CREATE_MM2
#undef CREATE_MM
} else if (device->fp16) {
} else
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
Expand Down Expand Up @@ -2242,6 +2245,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
}

#if defined(VK_KHR_cooperative_matrix)
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
coopmat_features.pNext = nullptr;
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
Expand All @@ -2251,6 +2255,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
last_struct = (VkBaseOutStructure *)&coopmat_features;
}
#endif

#if defined(VK_NV_cooperative_matrix2)
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
Expand Down Expand Up @@ -2283,7 +2288,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_EXT_subgroup_size_control");
}

#if defined(VK_KHR_cooperative_matrix)
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
#endif

if (coopmat2_support) {
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
Expand Down Expand Up @@ -2376,6 +2383,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_shader_float16_int8");
}

#if defined(VK_KHR_cooperative_matrix)
if (device->coopmat_support) {
// Query supported shapes
std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
Expand Down Expand Up @@ -2442,6 +2450,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (device->coopmat_support) {
device_extensions.push_back("VK_KHR_cooperative_matrix");
}
#endif

device->name = GGML_VK_NAME + std::to_string(idx);

Expand Down Expand Up @@ -2554,9 +2563,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
fp16_storage = true;
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
fp16_compute = true;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT")) {
coopmat_support = true;
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
Expand Down Expand Up @@ -2596,6 +2607,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
// Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;

#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
coopmat_features.pNext = nullptr;
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
Expand All @@ -2611,6 +2623,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
fp16 = fp16 && vk12_features.shaderFloat16;

coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
#endif

std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";

Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#version 460
#extension GL_KHR_cooperative_matrix : require
void main()
{
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,11 @@ void process_shaders() {
matmul_shaders(true, matmul_id, false, false, false);
matmul_shaders(true, matmul_id, false, false, true);

#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
#endif

#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Coopmat2, fp32acc and fp16acc
Expand Down