@@ -145,6 +145,28 @@ inline float blockReduceSum(
   return sharedScratch[0];
 }
 
+template <bool upper>
+inline device float& get_ref(device float* A, uint row, uint col, uint N);
+
+template <>
+inline device float& get_ref<true>(
+    device float* A,
+    uint row,
+    uint col,
+    uint N) {
+  return A[row * N + col];
+}
+
+template <>
+inline device float& get_ref<false>(
+    device float* A,
+    uint row,
+    uint col,
+    uint N) {
+  return A[row + col * N];
+}
+
+template <bool upper>
 kernel void factorDiagonalBlock(
     device float* A [[buffer(0)]],
     device int* info [[buffer(1)]],
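Note on the indexing helper: the two specializations are transposed views of the same buffer, so get_ref<false>(A, r, c, N) addresses the element that get_ref<true>(A, c, r, N) does. That is what lets one kernel body, written against get_ref<upper>, factor either the upper triangle of the row-major view or the lower triangle of the column-major view. A minimal sketch of a check for this aliasing, assuming the templates above are in scope (the kernel name checkGetRefTransposed is illustrative, not part of the patch):

#include <metal_stdlib>
using namespace metal;

// Writes 1 where the column-major accessor and the transposed
// row-major accessor alias the same buffer slot.
kernel void checkGetRefTransposed(
    device float* A [[buffer(0)]],
    device uint* ok [[buffer(1)]],
    constant uint& N [[buffer(2)]],
    uint2 gid [[thread_position_in_grid]]) {
  device float* lhs = &get_ref<false>(A, gid.y, gid.x, N);
  device float* rhs = &get_ref<true>(A, gid.x, gid.y, N);
  ok[gid.y * N + gid.x] = (lhs == rhs) ? 1u : 0u;
}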
@@ -171,7 +193,7 @@ kernel void factorDiagonalBlock(
   for (uint i = linear_tid; i < tileSize; i += group_size) {
     uint r = i / actSize;
     uint c = i % actSize;
-    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+    tile[r][c] = get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N);
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -244,10 +266,33 @@ kernel void factorDiagonalBlock(
   for (uint i = linear_tid; i < tileSize; i += group_size) {
     uint r = i / actSize;
     uint c = i % actSize;
-    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
+    get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N) = tile[r][c];
   }
 }
 
+template [[host_name("factorDiagonalBlockU")]]
+kernel void factorDiagonalBlock<true>(
+    device float* A [[buffer(0)]],
+    device int* info [[buffer(1)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 bid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]);
+
+template [[host_name("factorDiagonalBlockL")]]
+kernel void factorDiagonalBlock<false>(
+    device float* A [[buffer(0)]],
+    device int* info [[buffer(1)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 bid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]);
+
+template <bool upper>
 kernel void applyTRSM(
     device float* A [[buffer(0)]],
     constant uint& N [[buffer(2)]],
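The explicit instantiations use Metal's [[host_name]] attribute to give each specialization its own entry point (factorDiagonalBlockU / factorDiagonalBlockL), so the host selects the upper or lower variant purely by pipeline-function name. A minimal sketch of the same pattern on a toy kernel (fillSign and its host names are illustrative only):

#include <metal_stdlib>
using namespace metal;

template <bool flag>
kernel void fillSign(
    device float* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = flag ? 1.0f : -1.0f;
}

// Each instantiation is a separate host-visible function, so looking
// up "fillSignPos" in the library selects the <true> variant.
template [[host_name("fillSignPos")]] kernel void fillSign<true>(
    device float* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("fillSignNeg")]] kernel void fillSign<false>(
    device float* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]);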
@@ -283,12 +328,12 @@ kernel void applyTRSM(
   for (uint i = linear_tid; i < actSize_k * actSize_k; i += group_size) {
     uint r = i / actSize_k;
     uint c = i % actSize_k;
-    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
+    diag[i] = get_ref<upper>(A + batch_offset, k * NB + r, k * NB + c, N);
   }
   for (uint i = linear_tid; i < actSize_j * actSize_k; i += group_size) {
     uint r = i / actSize_k;
     uint c = i % actSize_k;
-    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+    target[i] = get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N);
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -332,10 +377,31 @@ kernel void applyTRSM(
   for (uint i = linear_tid; i < actSize_j * actSize_k; i += group_size) {
     uint r = i / actSize_k;
     uint c = i % actSize_k;
-    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
+    get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N) = target[i];
   }
 }
 
+template [[host_name("applyTRSMU")]]
+kernel void applyTRSM<true>(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]);
+
+template [[host_name("applyTRSML")]]
+kernel void applyTRSM<false>(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]);
+
+template <bool upper>
 kernel void applySYRK(
     device float* A [[buffer(0)]],
     constant uint& N [[buffer(2)]],
@@ -403,25 +469,37 @@ kernel void applySYRK(
       // Same logic to load/store Cfrag, Afrag, Bfrag...
       simdgroup_matrix<float, 8, 8> Cfrag;
       simdgroup_load(
-          Cfrag, &A[batch_offset + (row0 + sb_y) * N + (col0 + sb_x)], N);
+          Cfrag,
+          &get_ref<upper>(A + batch_offset, row0 + sb_y, col0 + sb_x, N),
+          N,
+          0,
+          !upper);
 
       for (uint kk = 0; kk < actSize_k; kk += 8) {
         simdgroup_load(
-            Afrag, &A[batch_offset + (row0 + sb_y) * N + (k * NB + kk)], N);
+            Afrag,
+            &get_ref<upper>(A + batch_offset, row0 + sb_y, k * NB + kk, N),
+            N,
+            0,
+            !upper);
         simdgroup_load(
             Bfrag,
-            &A[batch_offset + (col0 + sb_x) * N + (k * NB + kk)],
+            &get_ref<upper>(A + batch_offset, col0 + sb_x, k * NB + kk, N),
             N,
             /* matrix_origin = */ 0,
-            /* transpose = */ true);
+            /* transpose = */ upper);
 
         simdgroup_multiply(Prod, Afrag, Bfrag);
         simdgroup_multiply(Prod, Prod, negative_identity);
         simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
       }
 
       simdgroup_store(
-          Cfrag, &A[batch_offset + (row0 + sb_y) * N + (col0 + sb_x)], N);
+          Cfrag,
+          &get_ref<upper>(A + batch_offset, row0 + sb_y, col0 + sb_x, N),
+          N,
+          0,
+          !upper);
     }
   } else {
     // Fallback for non-multiple-of-8 dimensions
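Why the transpose flags flip rather than the arithmetic: simdgroup_load's trailing bool transposes the 8x8 tile during the load. A compressed sketch of the equivalence being exploited, assuming a tile M at base pointer p with row stride N (illustration, not patch code):

// If the buffer holds M in row-major order at p:
//   simdgroup_load(T, p, N, 0, /*transpose=*/false);  // T = M
//   simdgroup_load(T, p, N, 0, /*transpose=*/true);   // T = M^T
// If it instead holds M column-major, transpose=false already yields
// M^T. So switching upper <-> lower only toggles the flags (!upper
// for Cfrag/Afrag, upper for Bfrag), while the product
// simdgroup_multiply(Prod, Afrag, Bfrag) is left untouched.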
@@ -442,8 +520,10 @@ kernel void applySYRK(
 
         float sum = 0.0f;
         for (uint i = 0; i < actSize_k; i++) {
-          float a_val = A[batch_offset + (row0 + y) * N + k * NB + i];
-          float b_val = A[batch_offset + (col0 + x) * N + k * NB + i];
+          float a_val =
+              get_ref<upper>(A + batch_offset, row0 + y, k * NB + i, N);
+          float b_val =
+              get_ref<upper>(A + batch_offset, col0 + x, k * NB + i, N);
           sum = fma(a_val, b_val, sum);
         }
         sum_accumulator[y * tpg.x + x] += sum;
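The scalar fallback performs the same symmetric rank-k update as the simdgroup path, one element at a time; per output element it accumulates a dot product of two panel rows and subtracts it afterwards (sketch of the math, not patch code):

// For element (y, x) of the target tile:
//   C[row0 + y][col0 + x] -= sum over i < actSize_k of
//       Apanel[row0 + y][k * NB + i] * Apanel[col0 + x][k * NB + i]
// get_ref<upper> changes only how each operand is addressed, not the
// arithmetic; fma() fuses each multiply-add into one rounding step.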
@@ -452,13 +532,35 @@ kernel void applySYRK(
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint y = ty; y < actSize_j; y += tpg.y) {
       for (uint x = tx; x < actSize_h; x += tpg.x) {
-        A[batch_offset + (row0 + y) * N + col0 + x] -=
+        get_ref<upper>(A + batch_offset, row0 + y, col0 + x, N) -=
             sum_accumulator[y * tpg.x + x];
       }
     }
   }
 }
 
+template [[host_name("applySYRKU")]]
+kernel void applySYRK<true>(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]],
+    uint sgitg [[simdgroup_index_in_threadgroup]]);
+
+template [[host_name("applySYRKL")]]
+kernel void applySYRK<false>(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]],
+    uint sgitg [[simdgroup_index_in_threadgroup]]);
+
 kernel void applyPivots(
     device float* P [[buffer(0)]],
     device const int* pivots [[buffer(1)]],