Skip to content

Commit 989a348

Browse files
committed
cont : simplify data loading
1 parent c1fb380 commit 989a348

File tree

1 file changed: 10 additions and 30 deletions.

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7927,24 +7927,14 @@ kernel void kernel_mul_mm(
79277927

79287928
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
79297929
// load data and store to threadgroup memory
7930-
if (is_same<T0_4x4, block_q>::value) {
7930+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bounds_check) {
79317931
threadgroup_barrier(mem_flags::mem_threadgroup);
79327932

79337933
// no need for dequantization
7934-
if (FC_mul_mm_bounds_check) {
7935-
// bounds checks are required
7936-
for (short i = 0; i < 16; i++) {
7937-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7938-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7939-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
7940-
}
7941-
} else {
7942-
// do not perform bounds checks
7943-
FOR_UNROLL (short i = 0; i < 16; i++) {
7944-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7945-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7946-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = ((device T0 *) x)[i];
7947-
}
7934+
for (short i = 0; i < 16; i++) {
7935+
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7936+
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7937+
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
79487938
}
79497939
} else {
79507940
S0_4x4 temp_a;
@@ -8179,24 +8169,14 @@ kernel void kernel_mul_mm_id(
81798169

81808170
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
81818171
// load data and store to threadgroup memory
8182-
if (is_same<T0_4x4, block_q>::value) {
8172+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bounds_check) {
81838173
threadgroup_barrier(mem_flags::mem_threadgroup);
81848174

81858175
// no need for dequantization
8186-
if (FC_mul_mm_bounds_check) {
8187-
// bounds checks are required
8188-
for (short i = 0; i < 16; i++) {
8189-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8190-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8191-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
8192-
}
8193-
} else {
8194-
// do not perform bounds checks
8195-
FOR_UNROLL (short i = 0; i < 16; i++) {
8196-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8197-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8198-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = ((device T0 *) x)[i];
8199-
}
8176+
for (short i = 0; i < 16; i++) {
8177+
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8178+
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8179+
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
82008180
}
82018181
} else {
82028182
S0_4x4 temp_a;

0 commit comments

Comments (0)