@@ -12,11 +12,24 @@

import triton
import triton.language as tl
+ from triton.tools.tensor_descriptor import TensorDescriptor

- from tritonbench.utils.env_utils import is_hip_mi300
+ from tritonbench.utils.env_utils import is_cuda, is_fbcode, is_hip_mi300

from .triton_matmul_configs import get_full_amd_config_space

+ if not is_fbcode():
+     if is_cuda():
+         from triton._C.libtriton import nvidia
+
+         cublas_workspace = torch.empty(
+             32 * 1024 * 1024, device="cuda", dtype=torch.uint8
+         )
+         cublas = nvidia.cublas.CublasLt(cublas_workspace)
+     else:
+         cublas = None
+
+
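+ # NOTE: `cublas` exposes a cuBLASLt handle (with a 32 MiB workspace) that the
+ # benchmark can presumably use as a vendor baseline; it is absent in fbcode
+ # builds and None on non-CUDA devices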
if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
    tuning_configs = get_full_amd_config_space(False)
else:
@@ -56,7 +69,7 @@
    }
)
@triton.jit
- def streamk_gemm(
+ def streamk_amd_gemm(
    A,
    B,
    C,
@@ -274,7 +287,7 @@ def streamk_gemm(
        start_iter = end_iter


- def streamk_matmul(a, b, bias=None):
+ def streamk_amd_matmul(a, b, bias=None):
    M, K = a.shape
    _, N = b.shape
    dtype = a.dtype
@@ -350,7 +363,7 @@ def streamk_matmul(a, b, bias=None):
        and c.stride(0) >= 0
        and c.stride(1) >= 0
    )
-    streamk_gemm[(grids,)](
+    streamk_amd_gemm[(grids,)](
        a,
        b,
        c,
@@ -376,3 +389,253 @@ def streamk_matmul(a, b, bias=None):
    # print(c)
    # print(a @ b)
    return c
+
+ def _matmul_launch_metadata(grid, kernel, args):
+     ret = {}
+     M, N, K = args["M"], args["N"], args["K"]
+     ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
+     ret["flops8"] = 2.0 * M * N * K
+     if "c_ptr" in args:
+         bytes_per_elem = args["c_ptr"].element_size()
+     else:
+         bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
+     ret["bytes"] = bytes_per_elem * (M * K + N * K)
+     return ret
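+     # launch metadata is reported to profiling hooks (e.g. proton) on each
+     # launch; the "flops8" and "bytes" entries let tooling derive achieved
+     # FLOPS and memory bandwidth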
+
+
+ def matmul_get_configs(pre_hook=None):
+     return [
+         triton.Config(
+             {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK, "SK_BLOCK_K": skBK, "GROUP_M": 8},
+             num_stages=s,
+             num_warps=w,
+             pre_hook=pre_hook,
+         )  #
+         for BM in [128, 256]  #
+         for BN in [128, 256]  #
+         for BK in [32, 64, 128]  #
+         for skBK in [16, 32, 64, 128]  #
+         for s in ([2, 3, 4])  #
+         for w in [4, 8]  #
+     ]
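+ # the product space above is 2 (BM) * 2 (BN) * 3 (BK) * 4 (skBK) * 3 (stages)
+ # * 2 (warps) = 288 candidate configs per autotune key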
+
+ def matmul_tma_set_block_size_hook(nargs):
+     BLOCK_M = nargs["BLOCK_M"]
+     BLOCK_N = nargs["BLOCK_N"]
+     BLOCK_K = nargs["BLOCK_K"]
+     nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K]
+     nargs["b_desc"].block_shape = [BLOCK_N, BLOCK_K]
+     nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N]
+
+     SK_BLOCK_K = nargs["SK_BLOCK_K"]
+     nargs["a_desc_sk"].block_shape = [BLOCK_M, SK_BLOCK_K]
+     nargs["b_desc_sk"].block_shape = [BLOCK_N, SK_BLOCK_K]
+
+ @triton.autotune(
+     configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
+     key=["M", "N", "K"],
+ )
+ @triton.jit(launch_metadata=_matmul_launch_metadata)
+ def streamk_cuda_gemm(
+     # Pointer to a [BLOCK_M, BLOCK_K] TensorDescriptor
+     a_desc,
+     # Pointer to b [BLOCK_N, BLOCK_K] TensorDescriptor
+     b_desc,
+     # Pointer to a [BLOCK_M, SK_BLOCK_K] TensorDescriptor
+     a_desc_sk,
+     # Pointer to b [BLOCK_N, SK_BLOCK_K] TensorDescriptor
+     b_desc_sk,
+     # Pointer to c [BLOCK_M, BLOCK_N] TensorDescriptor
+     c_desc,
+     #
+     M,
+     N,
+     K,
+     # Tile dimensions for both phases
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     # K block dimension for the DDP phase
+     BLOCK_K: tl.constexpr,
+     # K block dimension for the Stream-K phase
+     SK_BLOCK_K: tl.constexpr,
+     # Group size for both phases
+     GROUP_M: tl.constexpr,
+     # TRUE if lowering for FP8 output
+     FP8_OUTPUT: tl.constexpr,
+     #
+     ENABLE_BUFFER_OPS_ASSUMES: tl.constexpr,
+     # Number of SMs on the device
+     NUM_SMS: tl.constexpr,
+ ):
+     if ENABLE_BUFFER_OPS_ASSUMES:
+         tl.assume(M >= 0)
+         tl.assume(N >= 0)
+         tl.assume(K >= 0)
+
+     dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
+
+     pid = tl.program_id(0)
+     num_pid = tl.num_programs(0)
+     num_tile_m = tl.cdiv(M, BLOCK_M)
+     num_tile_n = tl.cdiv(N, BLOCK_N)
+     num_tile_in_group = GROUP_M * num_tile_n
+
+     total_tiles = num_tile_m * num_tile_n
+
+     # number of full waves
+     W = total_tiles // NUM_SMS
+     # number of tiles in the partial wave
+     R = total_tiles % NUM_SMS
+     if W == 0 or R == 0:
+         total_ddp_tiles = num_pid
+         streamk_sms = 0
+     else:
+         # hybrid Stream-K + DDP: DDP on the first W - 1 waves, Stream-K on the
+         # last wave with full SM occupancy
+         total_ddp_tiles = num_pid - NUM_SMS
+         streamk_sms = NUM_SMS
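+         # e.g. total_tiles = 1000 on NUM_SMS = 132 gives W = 7, R = 76: the
+         # first 6 waves (792 tiles) run as DDP and the remaining 208 tiles
+         # are split across a single Stream-K wave of 132 programs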
+
+
+     # ----------------------------------------------------------------------------
+     # DDP phase
+     # ----------------------------------------------------------------------------
+     if pid < total_ddp_tiles:
+         # Each DDP-assigned program computes 1 full tile
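+         # tiles are visited in grouped (swizzled) order: GROUP_M tile-rows are
+         # interleaved so that consecutive programs touch nearby C tiles and
+         # reuse A/B data in L2 (the standard Triton grouped-ordering scheme)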
+         group_id = pid // num_tile_in_group
+         first_tile_m = group_id * GROUP_M
+         group_size_m = min(num_tile_m - first_tile_m, GROUP_M)
+         tile_m = first_tile_m + (pid % group_size_m)
+         tile_n = (pid % num_tile_in_group) // group_size_m
+
+         offs_am = tile_m * BLOCK_M
+         offs_bn = tile_n * BLOCK_N
+
+         accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+         work_units_per_tile = tl.cdiv(K, BLOCK_K)
+
+         for k in tl.range(0, work_units_per_tile, warp_specialize=True):
+             offs_k = k * BLOCK_K
+             a = a_desc.load([offs_am, offs_k])
+             b = b_desc.load([offs_bn, offs_k])
+             accumulator = tl.dot(a, b.T, accumulator)
+
+         c = accumulator.to(dtype)
+         c_desc.store([offs_am, offs_bn], c)
+
+     # ----------------------------------------------------------------------------
+     # Stream-K phase
+     # ----------------------------------------------------------------------------
+     else:
+         # index each Stream-K program as if it were a single SM
+         # (num_pid - total_ddp_tiles == streamk_sms)
+         worker_id = pid - total_ddp_tiles
+
+         work_units_per_tile = tl.cdiv(K, SK_BLOCK_K)
+         total_work_units = (total_tiles - total_ddp_tiles) * work_units_per_tile
+
+         # distribute work units as evenly as possible across SMs: the first
+         # `rem` programs each take one extra work unit, assigned contiguously
+         base = total_work_units // streamk_sms
+         rem = total_work_units % streamk_sms
+         work = tl.where(worker_id < rem, base + 1, base)
+         start = tl.where(
+             worker_id < rem,
+             worker_id * (base + 1),
+             rem * (base + 1) + (worker_id - rem) * base,
+         )
+         end = start + work - 1
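+         # e.g. total_work_units = 10 across streamk_sms = 4 gives base = 2,
+         # rem = 2: workers own units [0..2], [3..5], [6..7], [8..9]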
+
+         # if start >= total_work_units, nothing to do
+         if start >= total_work_units:
+             return
+
+         # this program is responsible for the work units from
+         # (st_tile_streamk, st_k_streamk) to (en_tile_streamk, en_k_streamk)
+         # *_k_streamk indexes along the K dimension and is one of
+         # {0, 1, ..., work_units_per_tile - 1}
+         st_tile_streamk = start // work_units_per_tile + total_ddp_tiles
+         st_k_streamk = start % work_units_per_tile
+         en_tile_streamk = end // work_units_per_tile + total_ddp_tiles
+         en_k_streamk = end % work_units_per_tile
+
+         for curr_tile in tl.range(st_tile_streamk, en_tile_streamk + 1, flatten=True):
+             # compute the tile associated with this work unit (consistent with the DDP phase)
+             group_id = curr_tile // num_tile_in_group
+             first_tile_m = group_id * GROUP_M
+             group_size_m = min(num_tile_m - first_tile_m, GROUP_M)
+             tile_m = first_tile_m + (curr_tile % group_size_m)
+             tile_n = (curr_tile % num_tile_in_group) // group_size_m
+
+             offs_am = tile_m * BLOCK_M
+             offs_bn = tile_n * BLOCK_N
+
+             # compute the start and end K index on this tile for this work unit
+             curr_st_k = tl.where(curr_tile == st_tile_streamk, st_k_streamk, 0)
+             curr_en_k = tl.where(
+                 curr_tile == en_tile_streamk, en_k_streamk, work_units_per_tile - 1
+             )
+
+             accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+             for k in tl.range(curr_st_k, curr_en_k + 1, warp_specialize=True):
+                 offs_k = k * SK_BLOCK_K
+                 # if the same TensorDescriptor shape is used for both phases,
+                 # just use DDP's descriptors (better performance)
+                 if BLOCK_K == SK_BLOCK_K:
+                     a = a_desc.load([offs_am, offs_k])
+                     b = b_desc.load([offs_bn, offs_k])
+                 else:
+                     a = a_desc_sk.load([offs_am, offs_k])
+                     b = b_desc_sk.load([offs_bn, offs_k])
+                 accumulator = tl.dot(a, b.T, accumulator)
+
+             c = accumulator.to(dtype)
+
+             if curr_st_k == 0 and curr_en_k == work_units_per_tile - 1:
+                 c_desc.store([offs_am, offs_bn], c)
+             else:
+                 # NOTE: known correctness issue with atomic_add
+                 c_desc.atomic_add([offs_am, offs_bn], c)
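+                 # partial K-splits from different programs accumulate into C
+                 # via atomics; the host pre-zeroes C (torch.zeros below) so
+                 # these adds start from zero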
+
+ def streamk_cuda_matmul(a, b):
+     assert a.dtype == b.dtype, "Incompatible dtypes"
+
+     M, K = a.shape
+     N, K = b.shape
+     dtype = a.dtype
+
+     c = torch.zeros((M, N), device=a.device, dtype=dtype)
+
+     dummy_block = [1, 1]
+     a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
+     b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
+     c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)
+
+     a_desc_sk = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
+     b_desc_sk = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
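+     # the [1, 1] block shapes are placeholders: the autotuner's pre_hook
+     # (matmul_tma_set_block_size_hook) overwrites each descriptor's
+     # block_shape for the config being launched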
+
+     num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+     def grid(META):
+         nonlocal a_desc, b_desc, c_desc
+         BLOCK_M = META["BLOCK_M"]
+         BLOCK_N = META["BLOCK_N"]
+         num_tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
+         W = num_tiles // num_sms
+         R = num_tiles % num_sms
+         if W == 0 or R == 0:
+             total_ddp_tiles = num_tiles
+             streamk_sms = 0
+         else:
+             total_ddp_tiles = (W - 1) * num_sms
+             streamk_sms = num_sms
+         return (total_ddp_tiles + streamk_sms,)
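+     # grid size = one program per DDP tile plus one program per SM for the
+     # Stream-K wave; re-evaluated per autotune config, since BLOCK_M and
+     # BLOCK_N change the tile count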
+
+     streamk_cuda_gemm[grid](
+         a_desc,
+         b_desc,
+         a_desc_sk,
+         b_desc_sk,
+         c_desc,  #
+         M,
+         N,
+         K,  #
+         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
+         ENABLE_BUFFER_OPS_ASSUMES=True,  #
+         NUM_SMS=num_sms,  #
+     )
+     return c
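+
+ # illustrative usage (hypothetical, not part of the benchmark): b is laid out
+ # [N, K], so the CUDA path computes a @ b.T:
+ #   a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
+ #   b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
+ #   c = streamk_cuda_matmul(a, b)  # ~= a @ b.T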