Skip to content

Commit 3ce7b72

Browse files
committed
wip
1 parent c39665f commit 3ce7b72

File tree

3 files changed

+56
-54
lines changed

3 files changed

+56
-54
lines changed

ggml/src/ggml-metal.m

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,19 +1015,19 @@ static void ggml_metal_encode_node(
10151015
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
10161016
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
10171017

1018-
//GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019-
//if (src0) {
1020-
// GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1021-
// ggml_is_contiguous(src0), src0->name);
1022-
//}
1023-
//if (src1) {
1024-
// GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1025-
// ggml_is_contiguous(src1), src1->name);
1026-
//}
1027-
//if (dst) {
1028-
// GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1029-
// dst->name);
1030-
//}
1018+
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019+
if (src0) {
1020+
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1021+
ggml_is_contiguous(src0), src0->name);
1022+
}
1023+
if (src1) {
1024+
GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1025+
ggml_is_contiguous(src1), src1->name);
1026+
}
1027+
if (dst) {
1028+
GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1029+
dst->name);
1030+
}
10311031

10321032
id<MTLDevice> device = ctx_dev->mtl_device;
10331033

@@ -1986,16 +1986,18 @@ static void ggml_metal_encode_node(
19861986
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
19871987
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
19881988
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1989-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1990-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1991-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1992-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1993-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1994-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1995-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1996-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1997-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1998-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1989+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1990+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1991+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1992+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1993+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
1994+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
1995+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
1996+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
1997+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1998+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1999+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
2000+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
19992001

20002002
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
20012003
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||

ggml/src/ggml-metal.metal

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,7 @@ void kernel_mul_mv_impl(
14631463
break;
14641464
}
14651465

1466-
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
1466+
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + i12*nb12 + i13*(ne12*nb12));
14671467

14681468
float sumf = 0;
14691469
for (int i = tiisg; i < ne00; i += 32) {
@@ -1483,7 +1483,7 @@ void kernel_mul_mv_impl(
14831483
break;
14841484
}
14851485

1486-
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
1486+
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + i12*nb12 + i13*(ne12*nb12));
14871487
device const T14 * y4 = (device const T14 *) y;
14881488

14891489
float sumf = 0;

tests/test-backend-ops.cpp

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,21 +3507,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
35073507
for (ggml_type type_a : base_types) {
35083508
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
35093509
// test cases without permutation
3510-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
3511-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
3512-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
3513-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
3514-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
3515-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
3516-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
3517-
3518-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
3519-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
3520-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
3521-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
3522-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
3523-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
3524-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
3510+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
3511+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
3512+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
3513+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
3514+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
3515+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
3516+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
3517+
3518+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
3519+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
3520+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
3521+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
3522+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
3523+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
3524+
//test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
35253525

35263526
// test cases with permutation
35273527
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
@@ -3537,14 +3537,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
35373537
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
35383538
}
35393539
}
3540-
for (ggml_type type_a : other_types) {
3541-
for (ggml_type type_b : {GGML_TYPE_F32}) {
3542-
if (ggml_blck_size(type_a) != 256) {
3543-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1}));
3544-
}
3545-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
3546-
}
3547-
}
3540+
//for (ggml_type type_a : other_types) {
3541+
// for (ggml_type type_b : {GGML_TYPE_F32}) {
3542+
// if (ggml_blck_size(type_a) != 256) {
3543+
// test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1}));
3544+
// }
3545+
// test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
3546+
// }
3547+
//}
35483548
#else
35493549
// m = a rows
35503550
// n = b rows
@@ -3564,12 +3564,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
35643564
}
35653565
#endif
35663566

3567-
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
3568-
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
3569-
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));
3570-
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 64, { 8, 1}, {4, 1}));
3571-
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
3572-
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
3567+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
3568+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
3569+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));
3570+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 64, { 8, 1}, {4, 1}));
3571+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
3572+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
35733573

35743574
// sycl backend will limit task global_range < MAX_INT
35753575
// test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)

0 commit comments

Comments
 (0)