Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7851,6 +7851,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
Expand Down
29 changes: 19 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Gathers elements of A into D using the indices stored in B.
//
// Each invocation handles a single column i00 (its x global ID). The y and z
// axes are walked with strided loops: an invocation starts at its own y/z
// global ID and advances by the total number of invocations launched along
// that axis, so every (i10, i11, i12) combination is visited even when the
// dispatch contains fewer workgroups than p.ne10 or p.ne11 * p.ne12.
void main() {
    const uint i00 = gl_GlobalInvocationID.x;

    // Column lies outside the row: nothing to copy for this invocation.
    if (i00 >= p.ne00) {
        return;
    }

    // Total invocations along y / z; constant for the whole dispatch.
    const uint stride_y = gl_WorkGroupSize.y * gl_NumWorkGroups.y;
    const uint stride_z = gl_WorkGroupSize.z * gl_NumWorkGroups.z;

    for (uint gid_z = gl_GlobalInvocationID.z; gid_z < p.ne11 * p.ne12; gid_z += stride_z) {
        // z packs two indices; they depend only on gid_z, so split them once
        // per outer iteration instead of once per inner iteration.
        const uint i11 = gid_z / p.ne12;
        const uint i12 = gid_z % p.ne12;

        for (uint gid_y = gl_GlobalInvocationID.y; gid_y < p.ne10; gid_y += stride_y) {
            const uint i10 = gid_y;

            // Index fetched from B; it selects which nb01-strided slice of A is read.
            const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];

            const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
            const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

#if defined(DATA_A_BF16)
            // BF16 input is widened to fp32 before conversion to FLOAT_TYPE.
            FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
            FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
            // Both preprocessor branches currently perform the same store.
#ifndef OPTIMIZATION_ERROR_WORKAROUND
            data_d[d_offset + i00] = D_TYPE(v);
#else
            data_d[d_offset + i00] = D_TYPE(v);
#endif
        }
    }
}
40 changes: 25 additions & 15 deletions ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
const uint i10 = gl_GlobalInvocationID.y;
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;

#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
Expand All @@ -22,20 +19,33 @@ void main() {
return;
}

const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;

const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];

const uint ib = a_offset + i00/QUANT_K; // block index
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

vec2 v = dequantize(ib, iqs, 0);
const vec2 dm = get_dm(ib, 0);
v = v * dm.x + dm.y;
const uint ib = a_offset + i00/QUANT_K; // block index
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
vec2 v = dequantize(ib, iqs, 0);
const vec2 dm = get_dm(ib, 0);
v = v * dm.x + dm.y;

data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);

gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
}
Loading