Skip to content

Commit 168a4b2

Browse files
committed
vulkan: fix batched matmul dequant for Q*_K
1 parent 69b7db8 commit 168a4b2

File tree

5 files changed

+5
-5
lines changed

5 files changed

+5
-5
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint i = gl_WorkGroupID.x * 256 + wgy;
13-
if (i >= p.M * p.K / QUANT_K) {
13+
if (i >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
13-
if (i >= p.M * p.K / QUANT_K) {
13+
if (i >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint ib = gl_WorkGroupID.x * 256 + wgy;
13-
if (ib >= p.M * p.K / QUANT_K) {
13+
if (ib >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint ib = gl_WorkGroupID.x * 256 + wgy;
13-
if (ib >= p.M * p.K / QUANT_K) {
13+
if (ib >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint i = gl_WorkGroupID.x * 256 + wgy;
13-
if (i >= p.M * p.K / QUANT_K) {
13+
if (i >= p.nel / QUANT_K) {
1414
return;
1515
}
1616
const uint tid = gl_LocalInvocationID.x;

0 commit comments

Comments
 (0)