Skip to content

Commit 96308e1

Browse files
committed
require ggml_contiguous_rows in supports_op and expect nb00=1 in the shader
1 parent 8ec8ea4 commit 96308e1

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ struct vk_op_sum_rows_push_constants
10181018
{
10191019
uint32_t n_cols;
10201020
uint32_t ne01, ne02;
1021-
uint32_t nb00, nb01, nb02, nb03;
1021+
uint32_t nb01, nb02, nb03;
10221022
uint32_t nb11, nb12, nb13;
10231023
float weight;
10241024
uint32_t misalign_offsets;
@@ -1032,7 +1032,6 @@ vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tens
10321032
p.n_cols = (uint32_t)n_cols;
10331033
p.ne01 = (uint32_t)src->ne[1];
10341034
p.ne02 = (uint32_t)src->ne[2];
1035-
p.nb00 = (uint32_t)src->nb[0] / type_size;
10361035
p.nb01 = (uint32_t)src->nb[1] / type_size;
10371036
p.nb02 = (uint32_t)src->nb[2] / type_size;
10381037
p.nb03 = (uint32_t)src->nb[3] / type_size;
@@ -8590,7 +8589,6 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
85908589

85918590
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
85928591
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
8593-
p.nb00 = 1; // treat src0 as flattened 1D tensor
85948592
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
85958593
}
85968594

@@ -11491,9 +11489,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1149111489
case GGML_OP_DIAG_MASK_INF:
1149211490
case GGML_OP_SOFT_MAX:
1149311491
case GGML_OP_SOFT_MAX_BACK:
11492+
return true;
1149411493
case GGML_OP_SUM:
1149511494
case GGML_OP_SUM_ROWS:
1149611495
case GGML_OP_MEAN:
11496+
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
1149711497
case GGML_OP_ARGMAX:
1149811498
case GGML_OP_COUNT_EQUAL:
1149911499
case GGML_OP_IM2COL:

ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ layout (push_constant) uniform parameter
1515
{
1616
uint n_cols;
1717
uint ne01, ne02;
18-
uint nb00, nb01, nb02, nb03;
18+
uint nb01, nb02, nb03;
1919
uint nb11, nb12, nb13;
2020
float weight;
2121
uint misalign_offsets;
@@ -53,7 +53,7 @@ void main() {
5353
tmp[col] = FLOAT_TYPE(0.0);
5454

5555
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
56-
tmp[col] += FLOAT_TYPE(data_a[src_idx + i * p.nb00]);
56+
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
5757
}
5858

5959
barrier();

0 commit comments

Comments
 (0)