Skip to content

Commit 2e2b22b

Browse files
authored
vulkan: Add missing bounds checking to scalar/coopmat1 mul_mat_id (ggml-org#15334)
1 parent 912ff8c commit 2e2b22b

File tree

2 files changed: 9 additions (+9) and 3 deletions (-3).

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -801,7 +801,7 @@ void main() {
 }
 #else
 const uint row_i = ic * BN + loadc_b + l;
-if (row_i < _ne1) {
+if (row_i < _ne1 && block + loadr_b < end_k) {
 const u16vec2 row_idx = row_ids[row_i];
 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
 } else {
@@ -875,7 +875,9 @@ void main() {
 
 const u16vec2 row_idx = row_ids[row_i];
 
-data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+if (dr + cm_row * TM + store_r < p.M) {
+data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+}
 }
 }
 }
@@ -925,7 +927,9 @@ void main() {
 #endif // MUL_MAT_ID
 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
 #ifdef MUL_MAT_ID
-data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+if (dr_warp + cr < p.M) {
+data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+}
 #else
 if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
 data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5824,6 +5824,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
 }
 
+test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
+
 for (ggml_type type_a : base_types) {
 for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
 for (int n_mats : {4, 8}) {

0 commit comments

Comments
 (0)