@@ -917,3 +917,108 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
917917 tt.return
918918 }
919919}

// -----

// COMMON-LABEL: bypass_lds_b_operand

// SYNC: scf.for
// SYNC: %[[load:.+]] = tt.load {{.*}} : tensor<8x2048x!tt.ptr<i8>, #linear>
// SYNC: %[[reshape1:.+]] = tt.reshape %arg24
// SYNC: %[[trans1:.+]] = tt.trans %[[reshape1]]
// SYNC: %[[reshape2:.+]] = tt.reshape %[[trans1]]
// SYNC: %[[trans2:.+]] = tt.trans %[[reshape2]] {{.*}} -> tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
// SYNC: tt.dot_scaled {{.*}}, %[[trans2]]
// SYNC: scf.yield {{.*}}, %[[load]]


// Layout aliases for the bypass_lds_b_operand test below. The #linear* maps
// describe the register/lane/warp bases used by the reshape/trans chain that
// repacks the B operand and the scales without a round trip through LDS.
#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
#linear = #ttg.linear<{register = [[0, 2], [0, 1]], lane = [[0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], warp = [[0, 0], [0, 0]], block = []}>
#linear1 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0]], lane = [[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 2, 0, 0, 0], [0, 0, 0, 4, 0, 0, 0], [0, 0, 0, 8, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0]], warp = [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]], block = []}>
#linear2 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0]], lane = [[0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0], [0, 0, 4, 0, 0, 0, 0], [0, 0, 8, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 2, 0]], warp = [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]], block = []}>
#linear3 = #ttg.linear<{register = [[0, 4], [16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0]], block = []}>
#linear4 = #ttg.linear<{register = [[0, 2], [0, 1]], lane = [[0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], warp = [[1, 0], [2, 0]], block = []}>
#linear5 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0]], lane = [[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 2, 0, 0, 0], [0, 0, 0, 4, 0, 0, 0], [0, 0, 0, 8, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0]], warp = [[1, 0, 0, 0, 0, 0, 0], [2, 0, 0, 0, 0, 0, 0]], block = []}>
#linear6 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0]], lane = [[0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0], [0, 0, 4, 0, 0, 0, 0], [0, 0, 8, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 2, 0]], warp = [[1, 0, 0, 0, 0, 0, 0], [2, 0, 0, 0, 0, 0, 0]], block = []}>
#linear7 = #ttg.linear<{register = [[0, 4], [16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[32, 0], [64, 0]], block = []}>
#linear8 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 1024], [1, 0]], lane = [[0, 16], [0, 32], [0, 64], [0, 128], [0, 256], [0, 512]], warp = [[2, 0], [4, 0]], block = []}>
#linear9 = #ttg.linear<{register = [[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 8], [0, 0, 4, 0, 0, 0], [0, 1, 0, 0, 0, 0]], lane = [[0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 2, 0], [0, 0, 0, 0, 4, 0], [0, 0, 0, 0, 8, 0], [0, 0, 1, 0, 0, 0], [0, 0, 2, 0, 0, 0]], warp = [[0, 2, 0, 0, 0, 0], [0, 4, 0, 0, 0, 0]], block = []}>
#linear10 = #ttg.linear<{register = [[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 8], [0, 0, 0, 4, 0, 0], [0, 1, 0, 0, 0, 0]], lane = [[0, 0, 1, 0, 0, 0], [0, 0, 2, 0, 0, 0], [0, 0, 4, 0, 0, 0], [0, 0, 8, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 2, 0, 0]], warp = [[0, 2, 0, 0, 0, 0], [0, 4, 0, 0, 0, 0]], block = []}>
#linear11 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 64], [16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 16], [0, 32]], warp = [[32, 0], [64, 0]], block = []}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], tilesPerWarp = [2, 2], instrShape = [16, 16], isTransposed = true}>
// Scaled-dot (mxfp4) GEMM loop on gfx950: the B tile (8x2048 i8) and both
// scale tensors are loaded, repacked in registers via reshape/trans chains,
// and fed straight into tt.dot_scaled, bypassing LDS for the B operand.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @bypass_lds_b_operand(%a_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %b_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %c_ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %a_scales_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %b_scales_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bn: i32 {tt.divisibility = 16 : i32}, %stride_ck: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}, %stride_asm: i32 {tt.divisibility = 16 : i32}, %stride_bsn: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<128> : tensor<32x128xi32, #blocked>
    %cst_0 = arith.constant dense<2048> : tensor<8x2048xi32, #blocked1>
    %cst_1 = arith.constant dense<256> : tensor<4x256xi32, #blocked2>
    %c1_i32 = arith.constant 1 : i32
    %pid_unified = arith.constant 7 : i32
    %c64_i32 = arith.constant 64 : i32
    %num_pid_n = arith.constant 127 : i32
    %cst_2 = arith.constant dense<256> : tensor<1x256xi32, #blocked3>
    %c128_i32 = arith.constant 128 : i32
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32
    %c8_i32 = arith.constant 8 : i32
    %c4_i32 = arith.constant 4 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma>
    // Program-id swizzling across 8 XCDs, then ceil-div to get the N-tile id.
    %pid_unified_4 = tt.get_program_id x : i32
    %xcd = arith.remsi %pid_unified_4, %c8_i32 : i32
    %local_pid = arith.divsi %pid_unified_4, %c8_i32 : i32
    %pid = arith.muli %xcd, %c8_i32 : i32
    %pid_9 = arith.addi %pid, %local_pid : i32
    %num_pid_n_7 = arith.addi %N, %num_pid_n : i32
    %num_pid_n_8 = arith.divsi %num_pid_n_7, %c128_i32 : i32
    %pid_n = arith.remsi %pid_9, %num_pid_n_8 : i32
    %offs_bn = arith.muli %pid_n, %c8_i32 : i32
    %offs_bn_15 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %offs_bn_16 = tt.splat %offs_bn : i32 -> tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %offs_bn_17 = arith.addi %offs_bn_16, %offs_bn_15 : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %offs_bn_18 = tt.splat %N : i32 -> tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %offs_bn_19 = arith.remsi %offs_bn_17, %offs_bn_18 : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %a_ptrs_28 = tt.splat %a_ptr : !tt.ptr<i8> -> tensor<32x128x!tt.ptr<i8>, #blocked>
    // B pointer tensor: row offsets scaled by %stride_bn plus a 0..2047 column range.
    %b_ptrs = tt.expand_dims %offs_bn_19 {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1xi32, #blocked1>
    %b_ptrs_29 = tt.splat %stride_bn : i32 -> tensor<8x1xi32, #blocked1>
    %b_ptrs_30 = arith.muli %b_ptrs, %b_ptrs_29 : tensor<8x1xi32, #blocked1>
    %b_ptrs_31 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %b_ptrs_32 = tt.expand_dims %b_ptrs_31 {axis = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x2048xi32, #blocked1>
    %b_ptrs_33 = tt.broadcast %b_ptrs_30 : tensor<8x1xi32, #blocked1> -> tensor<8x2048xi32, #blocked1>
    %b_ptrs_34 = tt.broadcast %b_ptrs_32 : tensor<1x2048xi32, #blocked1> -> tensor<8x2048xi32, #blocked1>
    %b_ptrs_35 = arith.addi %b_ptrs_33, %b_ptrs_34 : tensor<8x2048xi32, #blocked1>
    %b_ptrs_36 = tt.splat %b_ptr : !tt.ptr<i8> -> tensor<8x2048x!tt.ptr<i8>, #blocked1>
    %b_ptrs_37 = tt.addptr %b_ptrs_36, %b_ptrs_35 : tensor<8x2048x!tt.ptr<i8>, #blocked1>, tensor<8x2048xi32, #blocked1>
    %b_scale_ptrs_53 = tt.splat %b_scales_ptr : !tt.ptr<i8> -> tensor<4x256x!tt.ptr<i8>, #blocked2>
    %a_scale_ptrs_56 = tt.splat %a_scales_ptr : !tt.ptr<i8> -> tensor<1x256x!tt.ptr<i8>, #blocked3>
    // Main K loop: carries accumulator plus the four pointer tensors.
    %accumulator:5 = scf.for %accumulator_83 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%a_scale_ptrs_84 = %a_scale_ptrs_56, %arg16 = %cst_3, %b_scale_ptrs_85 = %b_scale_ptrs_53, %a_ptrs_86 = %a_ptrs_28, %b_ptrs_87 = %b_ptrs_37) -> (tensor<1x256x!tt.ptr<i8>, #blocked3>, tensor<32x128xf32, #mma>, tensor<4x256x!tt.ptr<i8>, #blocked2>, tensor<32x128x!tt.ptr<i8>, #blocked>, tensor<8x2048x!tt.ptr<i8>, #blocked1>) : i32 {
      // A scales: 1x256 -> 32x8 via reshape/trans/reshape, entirely in registers.
      %a_scales = tt.load %a_scale_ptrs_84 : tensor<1x256x!tt.ptr<i8>, #blocked3>
      %a_scales_88 = ttg.convert_layout %a_scales : tensor<1x256xi8, #blocked3> -> tensor<1x256xi8, #linear>
      %a_scales_89 = tt.reshape %a_scales_88 : tensor<1x256xi8, #linear> -> tensor<1x1x4x16x2x2x1xi8, #linear1>
      %a_scales_90 = tt.trans %a_scales_89 {order = array<i32: 0, 5, 3, 1, 4, 2, 6>} : tensor<1x1x4x16x2x2x1xi8, #linear1> -> tensor<1x2x16x1x2x4x1xi8, #linear2>
      %a_scales_91 = tt.reshape %a_scales_90 : tensor<1x2x16x1x2x4x1xi8, #linear2> -> tensor<32x8xi8, #linear3>
      // B scales: 4x256 -> 128x8 via the same repacking pattern.
      %b_scales = tt.load %b_scale_ptrs_85 : tensor<4x256x!tt.ptr<i8>, #blocked2>
      %b_scales_92 = ttg.convert_layout %b_scales : tensor<4x256xi8, #blocked2> -> tensor<4x256xi8, #linear4>
      %b_scales_93 = tt.reshape %b_scales_92 : tensor<4x256xi8, #linear4> -> tensor<4x1x4x16x2x2x1xi8, #linear5>
      %b_scales_94 = tt.trans %b_scales_93 {order = array<i32: 0, 5, 3, 1, 4, 2, 6>} : tensor<4x1x4x16x2x2x1xi8, #linear5> -> tensor<4x2x16x1x2x4x1xi8, #linear6>
      %b_scales_95 = tt.reshape %b_scales_94 : tensor<4x2x16x1x2x4x1xi8, #linear6> -> tensor<128x8xi8, #linear7>
      %a = tt.load %a_ptrs_86 : tensor<32x128x!tt.ptr<i8>, #blocked>
      %b = tt.load %b_ptrs_87 : tensor<8x2048x!tt.ptr<i8>, #blocked1>
      // B operand: 8x2048 -> 128x128 dot_op layout without an LDS round trip.
      %accumulator_96 = ttg.convert_layout %b : tensor<8x2048xi8, #blocked1> -> tensor<8x2048xi8, #linear8>
      %b_97 = tt.reshape %accumulator_96 : tensor<8x2048xi8, #linear8> -> tensor<1x8x8x1x16x16xi8, #linear9>
      %b_98 = tt.trans %b_97 {order = array<i32: 0, 1, 4, 2, 3, 5>} : tensor<1x8x8x1x16x16xi8, #linear9> -> tensor<1x8x16x8x1x16xi8, #linear10>
      %b_99 = tt.reshape %b_98 : tensor<1x8x16x8x1x16xi8, #linear10> -> tensor<128x128xi8, #linear11>
      %b_100 = tt.trans %b_99 {order = array<i32: 1, 0>} : tensor<128x128xi8, #linear11> -> tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
      %a_101 = ttg.convert_layout %a : tensor<32x128xi8, #blocked> -> tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
      %accumulator_102 = tt.dot_scaled %a_101 scale %a_scales_91, %b_100 scale %b_scales_95, %cst_3 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x8xi8, #linear3> * tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x8xi8, #linear7> -> tensor<32x128xf32, #mma>
      %accumulator_103 = arith.addf %arg16, %accumulator_102 : tensor<32x128xf32, #mma>
      // Advance all pointers by one K step.
      %a_ptrs_104 = tt.addptr %a_ptrs_86, %cst : tensor<32x128x!tt.ptr<i8>, #blocked>, tensor<32x128xi32, #blocked>
      %b_ptrs_105 = tt.addptr %b_ptrs_87, %cst_0 : tensor<8x2048x!tt.ptr<i8>, #blocked1>, tensor<8x2048xi32, #blocked1>
      %a_scale_ptrs_106 = tt.addptr %a_scale_ptrs_84, %cst_2 : tensor<1x256x!tt.ptr<i8>, #blocked3>, tensor<1x256xi32, #blocked3>
      %b_scale_ptrs_107 = tt.addptr %b_scale_ptrs_85, %cst_1 : tensor<4x256x!tt.ptr<i8>, #blocked2>, tensor<4x256xi32, #blocked2>
      scf.yield %a_scale_ptrs_106, %accumulator_103, %b_scale_ptrs_107, %a_ptrs_104, %b_ptrs_105 : tensor<1x256x!tt.ptr<i8>, #blocked3>, tensor<32x128xf32, #mma>, tensor<4x256x!tt.ptr<i8>, #blocked2>, tensor<32x128x!tt.ptr<i8>, #blocked>, tensor<8x2048x!tt.ptr<i8>, #blocked1>
    }
    tt.return
  }
}