Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ if (Vulkan_FOUND)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
endif()

# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)

if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_EXT_bfloat16 supported by glslc")
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
endif()

target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

Expand Down Expand Up @@ -142,6 +158,7 @@ if (Vulkan_FOUND)
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
Expand Down
196 changes: 175 additions & 21 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
endif()
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
Expand Down
11 changes: 9 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ void main() {
// fast path for when all four iterations are in-bounds
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#ifndef OPTIMIZATION_ERROR_WORKAROUND

#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
Expand All @@ -31,7 +35,10 @@ void main() {
continue;
}

#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ void main() {
return;
}

#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
Expand Down
8 changes: 7 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
}
#endif

#if defined(DATA_A_BF16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
}
#endif

#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
Expand Down Expand Up @@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif

#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}
Expand Down
9 changes: 7 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@ void main() {
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

#if defined(DATA_A_BF16)
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
data_d[d_offset + i00] = D_TYPE(v);
#else
data_d[d_offset + i00] = data_a[a_offset + i00];
data_d[d_offset + i00] = D_TYPE(v);
#endif
}
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
Expand Down
39 changes: 31 additions & 8 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif

#if defined(DATA_A_BF16) && defined(COOPMAT)
#extension GL_EXT_bfloat16 : enable
#endif

#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
Expand All @@ -29,6 +33,10 @@
#define LOAD_VEC_B 1
#endif

#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
Expand Down Expand Up @@ -202,8 +210,8 @@ void main() {
#endif

#ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
Expand Down Expand Up @@ -248,6 +256,21 @@ void main() {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
Expand Down Expand Up @@ -695,21 +718,21 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
Expand Down
37 changes: 23 additions & 14 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#ifdef DATA_A_BF16
#extension GL_EXT_bfloat16 : enable
#endif

#include "types.comp"

Expand Down Expand Up @@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#define store_scales(a)
#endif

#if defined(DATA_A_BF16)
#define MAT_TYPE bfloat16_t
#else
#define MAT_TYPE FLOAT_TYPE
#endif

#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};

Expand Down Expand Up @@ -271,8 +280,8 @@ void main() {

// Manually partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
Expand All @@ -286,8 +295,8 @@ void main() {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
Expand All @@ -310,8 +319,8 @@ void main() {

// Manually partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
Expand All @@ -325,8 +334,8 @@ void main() {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
Expand All @@ -350,8 +359,8 @@ void main() {

// Manually partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
Expand All @@ -365,8 +374,8 @@ void main() {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
Expand Down Expand Up @@ -405,8 +414,8 @@ void main() {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}

coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#version 460

#extension GL_EXT_bfloat16 : require

void main()
{
}
27 changes: 27 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@
#endif
#endif

#if defined(DATA_A_BF16)
#define QUANT_K 1
#define QUANT_R 1

#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE uint16_t
#elif LOAD_VEC_A == 4
#define A_TYPE u16vec4
#elif LOAD_VEC_A == 8
#error unsupported
#endif
#endif

#define QUANT_K_Q4_0 32
#define QUANT_R_Q4_0 2

Expand Down Expand Up @@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
}
#endif

// returns the bfloat value in the low 16b.
// See ggml_compute_fp32_to_bf16
uint32_t fp32_to_bf16(float f)
{
uint32_t u = floatBitsToUint(f);
u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
return u;
}

float bf16_to_fp32(uint32_t u)
{
return uintBitsToFloat(u << 16);
}

#endif // !defined(GGML_TYPES_COMP)
Loading
Loading