@@ -422,3 +422,192 @@ def test_runtime_tensor_copy(BLOCK_M, BLOCK_N):
    b_triton = b_device.cpu()
    assert torch.equal(b_triton, a)


@gluon.jit
def mxgemm_kernel(a_ptr, b_ptr, c_ptr, a_scale, b_scale, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn, stride_scale, DTYPE_A: ttgl.constexpr, DTYPE_B: ttgl.constexpr,
                  SCALE_BLOCK: ttgl.constexpr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr,
                  BLOCK_K: ttgl.constexpr, GROUP_SIZE_M: ttgl.constexpr):
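    # e2m1 (mxfp4) packs two 4-bit values per byte, so the packed K extent of an
    # operand is half its logical K; the fp8 formats store one value per byte.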
    DIV_FACTOR_A: ttgl.constexpr = 2 if DTYPE_A == "e2m1" else 1
    DIV_FACTOR_B: ttgl.constexpr = 2 if DTYPE_B == "e2m1" else 1
    BLOCK_K_SCALE: ttgl.constexpr = BLOCK_K // SCALE_BLOCK
    BLOCK_K_PACKED_A: ttgl.constexpr = BLOCK_K // DIV_FACTOR_A
    BLOCK_K_PACKED_B: ttgl.constexpr = BLOCK_K // DIV_FACTOR_B

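    # Layouts: blocked layouts for the global loads, WMMA layouts for the scaled
    # matrix-core op (the "packed" variant halves the K instr shape for mxfp4),
    # and linear layouts describing how scale elements map to registers/lanes/warps.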
    BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [8, 4], [4, 1], [1, 0])
    A_BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0])
    B_BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [16, 2], [4, 1], [1, 0])

    WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2],
                                                         instr_shape=[16, 16, 128])
    WMMA_LAYOUT_PACKED: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2],
                                                                instr_shape=[16, 16, 64])
    A_SCALE_LINEAR_LAYOUT: ttgl.constexpr = ttgl.DistributedLinearLayout(
        reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp_bases=[[0, 0], [16, 0]],
        block_bases=[], shape=[32, 4])
    B_SCALE_LINEAR_LAYOUT: ttgl.constexpr = ttgl.DistributedLinearLayout(
        reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp_bases=[[16, 0], [0, 0]],
        block_bases=[], shape=[32, 4])

    DOT_LAYOUT_A: ttgl.constexpr = ttgl.DotOperandLayout(
        operand_index=0, parent=WMMA_LAYOUT_PACKED if DTYPE_A == "e2m1" else WMMA_LAYOUT, k_width=16)
    DOT_LAYOUT_B: ttgl.constexpr = ttgl.DotOperandLayout(
        operand_index=1, parent=WMMA_LAYOUT_PACKED if DTYPE_B == "e2m1" else WMMA_LAYOUT, k_width=16)

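    # Standard grouped-ordering pid swizzle, as in the Triton matmul tutorial.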
    pid = ttgl.program_id(axis=0)
    num_pid_m = ttgl.cdiv(M, BLOCK_M)
    num_pid_n = ttgl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

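    # A/B offsets index the packed K extent; scale offsets index K / SCALE_BLOCK.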
    offs_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, A_BLOCKED_LAYOUT))) % M
    offs_ak = ttgl.arange(0, BLOCK_K_PACKED_A, layout=ttgl.SliceLayout(0, A_BLOCKED_LAYOUT))
    offs_bk = ttgl.arange(0, BLOCK_K_PACKED_B, layout=ttgl.SliceLayout(1, B_BLOCKED_LAYOUT))
    offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, B_BLOCKED_LAYOUT))) % N

    offs_scale_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % M
    offs_scale_ak = ttgl.arange(0, BLOCK_K_SCALE, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))
    offs_scale_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % N
    offs_scale_bk = ttgl.arange(0, BLOCK_K_SCALE, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))

    a_scale_ptr = a_scale + offs_scale_am[:, None] * stride_scale + offs_scale_ak[None, :]
    b_scale_ptr = b_scale + offs_scale_bn[:, None] * stride_scale + offs_scale_bk[None, :]
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

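    # Main K loop: load tiles and their scales, convert them to the operand
    # layouts expected by the scaled WMMA op, and accumulate in fp32.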
    accumulator = ttgl.zeros((BLOCK_M, BLOCK_N), dtype=ttgl.float32, layout=WMMA_LAYOUT)
    for k in range(0, ttgl.cdiv(K, BLOCK_K)):
        k_remaining_a = K - k * BLOCK_K_PACKED_A
        k_remaining_b = K - k * BLOCK_K_PACKED_B
        valid_k_a = offs_ak < k_remaining_a
        valid_k_b = offs_bk < k_remaining_b

        scale_a = ttgl.load(a_scale_ptr)
        scale_b = ttgl.load(b_scale_ptr)
        scale_a = ttgl.convert_layout(scale_a, A_SCALE_LINEAR_LAYOUT)
        scale_b = ttgl.convert_layout(scale_b, B_SCALE_LINEAR_LAYOUT)

        a = ttgl.load(a_ptrs, mask=valid_k_a[None, :], other=0.0)
        b = ttgl.load(b_ptrs, mask=valid_k_b[:, None], other=0.0)
        a = ttgl.convert_layout(a, DOT_LAYOUT_A)
        b = ttgl.convert_layout(b, DOT_LAYOUT_B)

        accumulator = ttgl.amd.gfx1250.wmma_scaled(a, scale_a, DTYPE_A, b, scale_b, DTYPE_B, accumulator)

        a_ptrs += BLOCK_K_PACKED_A * stride_ak
        b_ptrs += BLOCK_K_PACKED_B * stride_bk

        a_scale_ptr += BLOCK_K_SCALE
        b_scale_ptr += BLOCK_K_SCALE

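    # Epilogue: write the fp32 accumulator back with bounds masking.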
    offs_cm = pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, WMMA_LAYOUT))
    offs_cn = pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, WMMA_LAYOUT))
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    ttgl.store(c_ptrs, accumulator, mask=c_mask)


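# Compile-only test: check that the scaled WMMA instruction is emitted for gfx1250.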
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 64), (32, 32, 128)])
@pytest.mark.parametrize("DTYPE_A", ["float8_e5m2", "float8_e4m3", "float4"])
@pytest.mark.parametrize("DTYPE_B", ["float8_e5m2", "float8_e4m3", "float4"])
def test_compile_mxgemm(BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B):
    scale_block = 32

    if BLOCK_K < 128:
        pytest.skip("NYI: don't support block shape smaller than instr shape")

    triton_dtype_converter = {"float8_e5m2": "fp8e5", "float8_e4m3": "fp8e4nv", "float4": "u8"}
    dot_scaled_dtype_converter = {"float8_e5m2": "e5m2", "float8_e4m3": "e4m3", "float4": "e2m1"}

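    # Compile ahead of time against an explicit gfx1250 target; no kernel launch.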
    k = triton.compile(
        gluon._runtime.GluonASTSource(
            fn=mxgemm_kernel, signature={
                "a_ptr": f"*{triton_dtype_converter[DTYPE_A]}", "b_ptr": f"*{triton_dtype_converter[DTYPE_B]}",
                "c_ptr": "*fp32", "a_scale": "*u8", "b_scale": "*u8", "M": "i32", "N": "i32", "K": "i32",
                "stride_am": "i32", "stride_ak": "i32", "stride_bk": "i32", "stride_bn": "i32", "stride_cm": "i32",
                "stride_cn": "i32", "stride_scale": "i32", "DTYPE_A": "constexpr", "DTYPE_B": "constexpr",
                "SCALE_BLOCK": "constexpr", "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr",
                "GROUP_SIZE_M": "constexpr"
            }, constexprs={
                "DTYPE_A": dot_scaled_dtype_converter[DTYPE_A], "DTYPE_B": dot_scaled_dtype_converter[DTYPE_B],
                "SCALE_BLOCK": scale_block, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K,
                "GROUP_SIZE_M": 1
            }), target=GPUTarget("hip", "gfx1250", 32))

    amdgcn = k.asm["amdgcn"]
    pattern = "v_wmma_scale_f32_16x16x128_f8f6f4"
    assert re.search(pattern, amdgcn), f"Can't find instruction {pattern} in AMDGCN assembly"


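# End-to-end test: run the kernel and compare against a float32 torch reference.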
@pytest.mark.parametrize("M, N, K", [(32, 32, 128), (128, 128, 512), (1, 8192, 512)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 128), (64, 64, 128)])
@pytest.mark.parametrize("DTYPE_A", ["float8_e5m2", "float8_e4m3", "float4"])
@pytest.mark.parametrize("DTYPE_B", ["float8_e5m2", "float8_e4m3", "float4"])
def test_runtime_mxgemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B):
    scale_block = 32

    torch.manual_seed(0)

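    # Reference: broadcast each per-block scale across its SCALE_BLOCK elements
    # along K, then compute the whole product in float32.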
    def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K):
        a_scale_f32 = a_scale.to(torch.float32).repeat_interleave(scale_block, dim=1)[:M, :K]
        b_scale_f32 = b_scale.to(torch.float32).repeat_interleave(scale_block, dim=1).T.contiguous()[:K, :N]

        a_f32 = a.to(torch.float32)
        b_f32 = b.to(torch.float32)

        return torch.matmul(a_f32 * a_scale_f32, b_f32 * b_scale_f32).to(torch.float32)

    def init_data(dtype, d0: int, d1: int):
        if dtype == "float4":
            return MXFP4Tensor(size=(d0, d1)).random()
        elif dtype == "float8_e5m2":
            return torch.randint(20, 40, (d0, d1), dtype=torch.uint8).view(torch.float8_e5m2)
        elif dtype == "float8_e4m3":
            return torch.randint(20, 40, (d0, d1), dtype=torch.uint8).view(torch.float8_e4m3fn)
        else:
            raise NotImplementedError(f"NYI: unsupported dtype: {dtype}")

    a = init_data(DTYPE_A, M, K)
    b = init_data(DTYPE_B, K, N)
    a_size = (M, (K + scale_block - 1) // scale_block)
    b_size = (N, (K + scale_block - 1) // scale_block)
    a_scale = MXScaleTensor(size=a_size).random(low=1.0, high=32.0)
    b_scale = MXScaleTensor(size=b_size).random(low=1.0, high=32.0)

    c_ref = torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K)

    a_scale = a_scale.data
    b_scale = b_scale.data

    # mxfp4 inputs need to be packed along the K dim, i.e. two mxfp4 values per uint8.
    if DTYPE_A in ["float4", "float6_e2m3", "float6_e3m2"]:
        a = a.to_packed_tensor(dim=1)
    if DTYPE_B in ["float4", "float6_e2m3", "float6_e3m2"]:
        b = b.to_packed_tensor(dim=0)

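    # Move operands to the GPU and record the strides the kernel expects.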
    c_d = torch.zeros(M, N, dtype=torch.float32).cuda()
    a_d = a.data.contiguous().cuda()
    b_d = b.data.contiguous().cuda()
    a_scale_d = a_scale.cuda()
    b_scale_d = b_scale.cuda()

    stride_am, stride_ak = a_d.stride(0), a_d.stride(1)
    stride_bk, stride_bn = b_d.stride(0), b_d.stride(1)
    stride_cm, stride_cn = c_d.stride(0), c_d.stride(1)
    stride_scale = a_scale_d.stride(0)

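    # One program per output tile; group_size_m = 1 makes the grouped swizzle trivial.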
    num_blocks = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    grid = [num_blocks, 1, 1]
    group_size_m = 1

    dtype_converter = {"float8_e5m2": "e5m2", "float8_e4m3": "e4m3", "float4": "e2m1"}

    mxgemm_kernel[grid](a_d, b_d, c_d, a_scale_d, b_scale_d, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                        stride_cm, stride_cn, stride_scale, dtype_converter[DTYPE_A], dtype_converter[DTYPE_B],
                        scale_block, BLOCK_M, BLOCK_N, BLOCK_K, group_size_m, num_warps=4, num_ctas=1)

    torch.testing.assert_close(c_d.cpu(), c_ref.cpu(), rtol=1e-5, atol=1e-8)