@@ -70,8 +70,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
70
70
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 } {
71
71
// COM: tt.load -> tt.trans -> tt.dot chain, in a loop.
72
72
// COM: where the 'make_tensor_ptr' result is loop carried.
73
- tt.func public @fuseLoadWithTrans3 (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >, %arg2: !tt.ptr <f32 >) {
74
- %c4_i32 = arith.constant 4 : i32
73
+ tt.func public @fuseLoadWithTrans3 (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
75
74
%c1024_i32 = arith.constant 1024 : i32
76
75
%c0_i32 = arith.constant 0 : i32
77
76
%c32_i32 = arith.constant 32 : i32
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
80
79
%c1_i64 = arith.constant 1 : i64
81
80
%c1024_i64 = arith.constant 1024 : i64
82
81
%cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 , #mma >
83
- %0 = tt.get_program_id x : i32
84
- %1 = arith.divsi %0 , %c16_i32 : i32
85
- %2 = arith.muli %1 , %c4_i32 : i32
86
- %3 = arith.subi %c4_i32 , %2 : i32
87
- %4 = arith.minsi %3 , %c4_i32 : i32
88
- %5 = arith.remsi %0 , %c16_i32 : i32
89
- %6 = arith.remsi %5 , %4 : i32
90
- %7 = arith.addi %2 , %6 : i32
91
- %8 = arith.divsi %5 , %4 : i32
92
82
%9 = tt.make_tensor_ptr %arg0 , [%c1024_i64 , %c1024_i64 ], [%c1024_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xbf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>>
93
83
%10 = tt.make_tensor_ptr %arg1 , [%c1024_i64 , %c1_i64 ], [%c1_i64 , %c1024_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xbf16 , #linear >>
94
84
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args (%arg4 = %cst , %arg5 = %c0_i32 , %arg6 = %10 ) -> (tensor <256 x256 xf32 , #mma >, i32 , !tt.ptr <tensor <256 x32 xbf16 , #linear >>) : i32 {
@@ -116,13 +106,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
116
106
117
107
// -----
118
108
109
+ #linear = #ttg.linear <{register = [[0 , 1 ], [0 , 2 ], [0 , 4 ], [0 , 8 ], [16 , 0 ], [0 , 16 ], [0 , 32 ]], lane = [[1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ]], warp = [[0 , 0 ], [0 , 0 ]], block = []}>
110
+ #mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [4 , 1 ], repCluster = [2 , 2 ], A = [16 , 16 ], B = [16 , 32 ], C = [16 , 32 ]}>
111
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 } {
112
+ // COM: tt.load -> tt.trans -> tt.dot chain, repeated in two consecutive scf.for loops.
113
+ // COM: Both loops advance and load through the same block pointer (%9), which is created once, outside the loops, by a single tt.make_tensor_ptr. The CHECK lines below expect no tt.trans to remain, two tt.make_tensor_ptr ops on %arg2 with shape and strides swapped, and "column_major" block loads feeding tt.dot directly.
114
+ tt.func public @fuseLoadWithTrans4 (%arg0: i32 , %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >) {
115
+ %c32_i32 = arith.constant 32 : i32
116
+ %c0_i32 = arith.constant 0 : i32
117
+ %c64_i64 = arith.constant 64 : i64
118
+ %c1_i64 = arith.constant 1 : i64
119
+ %cst_3 = arith.constant dense <0.000000e+00 > : tensor <64 x32 xf32 , #mma >
120
+ %7 = tt.make_tensor_ptr %arg1 , [%c1_i64 , %c64_i64 ], [%c64_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>>
121
+ %9 = tt.make_tensor_ptr %arg2 , [%c1_i64 , %c64_i64 ], [%c64_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x64 xf16 , #linear >>
122
+ %24 = tt.advance %7 , [%arg0 , %c0_i32 ] : <tensor <64 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>>
123
+ %25 = tt.load %24 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <64 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>>
124
+ %29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args (%arg13 = %arg0 ) -> (i32 ) : i32 {
125
+ %adv1 = tt.advance %9 , [%arg13 , %c0_i32 ] : <tensor <32 x64 xf16 , #linear >>
126
+ %load1 = tt.load %adv1 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <32 x64 xf16 , #linear >>
127
+ %trans1 = tt.trans %load1 {order = array<i32 : 1 , 0 >} : tensor <32 x64 xf16 , #linear > -> tensor <64 x32 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
128
+ %dot1 = tt.dot %25 , %trans1 , %cst_3 , inputPrecision = tf32 : tensor <64 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> * tensor <64 x32 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <64 x32 xf32 , #mma >
129
+ %76 = arith.addi %arg13 , %c32_i32 : i32
130
+ scf.yield %76 : i32
131
+ }
132
+ %38:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args (%arg13 = %arg0 ) -> (i32 ) : i32 {
133
+ %adv2 = tt.advance %9 , [%arg13 , %c0_i32 ] : <tensor <32 x64 xf16 , #linear >>
134
+ %load2 = tt.load %adv2 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <32 x64 xf16 , #linear >>
135
+ %trans2 = tt.trans %load2 {order = array<i32 : 1 , 0 >} : tensor <32 x64 xf16 , #linear > -> tensor <64 x32 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
136
+ %dot2 = tt.dot %25 , %trans2 , %cst_3 , inputPrecision = tf32 : tensor <64 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> * tensor <64 x32 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <64 x32 xf32 , #mma >
137
+ %81 = arith.addi %arg13 , %c32_i32 : i32
138
+ scf.yield %81 : i32
139
+ }
140
+ tt.return
141
+ }
142
+ // CHECK-LABEL: fuseLoadWithTrans4
143
+ // CHECK-NOT: tt.trans
144
+ // CHECK-COUNT-2: tt.make_tensor_ptr %arg2, [%c64_i64, %c1_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
145
+ // CHECK: scf.for {{.*}}
146
+ // CHECK: [[ADV1:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
147
+ // CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV1]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
148
+ // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
149
+ // CHECK: scf.yield {{.*}}
150
+ // CHECK: scf.for {{.*}}
151
+ // CHECK: [[ADV2:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
152
+ // CHECK: [[LOAD_B2:%.*]] = tt.load [[ADV2]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
153
+ // CHECK: tt.dot {{.*}}, [[LOAD_B2]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
154
+ // CHECK: scf.yield {{.*}}
155
+ }
156
+
157
+ // -----
158
+
119
159
#linear = #ttg.linear <{register = [[0 , 1 ], [0 , 2 ], [0 , 4 ], [0 , 8 ], [16 , 0 ], [0 , 16 ], [128 , 0 ]], lane = [[1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ]], warp = [[32 , 0 ], [64 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ]], block = []}>
120
160
#mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [8 , 4 ], repCluster = [4 , 2 ], A = [32 , 16 ], B = [16 , 32 ], C = [32 , 32 ]}>
121
161
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 } {
122
- // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load
123
- // COM: that 'feeds' the transpose operation is used.
124
- tt.func public @doNotFuseLoadWithTrans1 (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >, %arg2: !tt.ptr <f32 >) {
125
- %c4_i32 = arith.constant 4 : i32
162
+ // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load that 'feeds' the transpose operation is used.
163
+ tt.func public @doNotFuseLoadWithTrans1 (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
126
164
%c1024_i32 = arith.constant 1024 : i32
127
165
%c0_i32 = arith.constant 0 : i32
128
166
%c32_i32 = arith.constant 32 : i32
@@ -131,15 +169,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
131
169
%c1_i64 = arith.constant 1 : i64
132
170
%c1024_i64 = arith.constant 1024 : i64
133
171
%cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 , #mma >
134
- %0 = tt.get_program_id x : i32
135
- %1 = arith.divsi %0 , %c16_i32 : i32
136
- %2 = arith.muli %1 , %c4_i32 : i32
137
- %3 = arith.subi %c4_i32 , %2 : i32
138
- %4 = arith.minsi %3 , %c4_i32 : i32
139
- %5 = arith.remsi %0 , %c16_i32 : i32
140
- %6 = arith.remsi %5 , %4 : i32
141
- %7 = arith.addi %2 , %6 : i32
142
- %8 = arith.divsi %5 , %4 : i32
143
172
%9 = tt.make_tensor_ptr %arg0 , [%c1024_i64 , %c1024_i64 ], [%c1024_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xbf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>>
144
173
%10 = tt.make_tensor_ptr %arg1 , [%c1024_i64 , %c1_i64 ], [%c1_i64 , %c1024_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xbf16 , #linear >>
145
174
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args (%arg4 = %cst , %arg5 = %c0_i32 , %arg6 = %10 ) -> (tensor <256 x256 xf32 , #mma >, i32 , !tt.ptr <tensor <256 x32 xbf16 , #linear >>) : i32 {
@@ -166,7 +195,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
166
195
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 } {
167
196
// COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose.
168
197
// COM: In this case `%19` is used by the load that feeds the transpose and by a store operation.
169
- tt.func public @doNotFuseLoadWithTrans2 (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >, %arg2: !tt.ptr < f32 > ) {
198
+ tt.func public @doNotFuseLoadWithTrans2 (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
170
199
%c4_i32 = arith.constant 4 : i32
171
200
%c1024_i32 = arith.constant 1024 : i32
172
201
%c0_i32 = arith.constant 0 : i32
0 commit comments