@@ -11,13 +11,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
11
11
%cst = arith.constant dense <0.000000e+00 > : tensor <128 x256 xf16 , #mma >
12
12
13
13
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
14
- // CHECK-DAG: %[[CST_16384:.*]] = llvm.mlir.constant(16384 : i32) : i32
15
14
// CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
15
+ // CHECK-DAG: %[[CST_387:.*]] = llvm.mlir.constant(387 : i32) : i32
16
16
// CHECK-DAG: %[[CST_384:.*]] = llvm.mlir.constant(384 : i32) : i32
17
- // CHECK-DAG: %[[CST_112:.*]] = llvm.mlir.constant(112 : i32) : i32
17
+ // CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
18
+ // CHECK-DAG: %[[CST_48:.*]] = llvm.mlir.constant(48 : i32) : i32
18
19
// CHECK-DAG: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
19
- // CHECK-DAG: %[[CST_8 :.*]] = llvm.mlir.constant(8 : i32) : i32
20
- // CHECK-DAG: %[[CST_6 :.*]] = llvm.mlir.constant(6 : i32) : i32
20
+ // CHECK-DAG: %[[CST_14 :.*]] = llvm.mlir.constant(14 : i32) : i32
21
+ // CHECK-DAG: %[[CST_12 :.*]] = llvm.mlir.constant(12 : i32) : i32
21
22
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
22
23
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
23
24
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -35,36 +36,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
35
36
// CHECK: %[[VAL_26:.*]] = llvm.or %[[CST_0]], %[[VAL_25]] : i32
36
37
// CHECK: %[[VAL_27:.*]] = llvm.shl %[[warpId]], %[[CST_4]] : i32
37
38
// CHECK: %[[VAL_28:.*]] = llvm.or %[[VAL_26]], %[[VAL_27]] : i32
38
- // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_384 ]] : i32
39
- // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_6 ]] : i32
39
+ // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_3 ]] : i32
40
+ // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_14 ]] : i32
40
41
// CHECK: %[[VAL_31:.*]] = llvm.xor %[[CST_0]], %[[VAL_30]] : i32
41
- // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_112 ]] : i32
42
- // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_1 ]] : i32
42
+ // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_387 ]] : i32
43
+ // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_4 ]] : i32
43
44
// CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_31]], %[[VAL_33]] : i32
44
- // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_15 ]] : i32
45
- // CHECK: %[[VAL_36:.*]] = llvm.lshr %[[VAL_35]], %[[CST_0 ]] : i32
45
+ // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_48 ]] : i32
46
+ // CHECK: %[[VAL_36:.*]] = llvm.shl %[[VAL_35]], %[[CST_1 ]] : i32
46
47
// CHECK: %[[VAL_37:.*]] = llvm.xor %[[VAL_34]], %[[VAL_36]] : i32
47
- // CHECK: %[[VAL_38:.*]] = llvm.xor %[[CST_0]], %[[VAL_37]] : i32
48
- // CHECK: %[[VAL_39:.*]] = llvm.and %[[VAL_28]], %[[CST_511]] : i32
49
- // CHECK: %[[VAL_40:.*]] = llvm.shl %[[VAL_39]], %[[CST_3]] : i32
50
- // CHECK: %[[VAL_41:.*]] = llvm.xor %[[CST_0]], %[[VAL_40]] : i32
51
- // CHECK: %[[VAL_42:.*]] = llvm.xor %[[CST_0]], %[[VAL_41]] : i32
52
- // CHECK: %[[VAL_43:.*]] = llvm.xor %[[VAL_38]], %[[CST_0]] : i32
53
- // CHECK: %[[VAL_44:.*]] = llvm.lshr %[[VAL_43]], %[[CST_8]] : i32
54
- // CHECK: %[[VAL_45:.*]] = llvm.shl %[[VAL_44]], %[[CST_3]] : i32
55
- // CHECK: %[[offset:.*]] = llvm.add %[[VAL_45]], %[[VAL_43]] : i32
56
- // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
57
- // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<1xf16>
58
-
59
- // COM: Because the values per thread of DPAS layout is not contiguous. The values are stored in the SLM in a non-vectorized way.
60
- // COM: Total 64 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16) = 64
61
- // CHECK: llvm.store %[[VAL_66]], %[[VAL_65]] : vector<1xf16>, !llvm.ptr<3>
62
- // CHECK-COUNT-63: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
48
+ // CHECK: %[[VAL_38:.*]] = llvm.and %[[VAL_28]], %[[CST_12]] : i32
49
+ // CHECK: %[[VAL_39:.*]] = llvm.lshr %[[VAL_38]], %[[CST_0]] : i32
50
+ // CHECK: %[[VAL_40:.*]] = llvm.xor %[[VAL_37]], %[[VAL_39]] : i32
51
+ // CHECK: %[[VAL_41:.*]] = llvm.and %[[VAL_28]], %[[CST_64]] : i32
52
+ // CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
53
+ // CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8192]] : i1, i32
54
+ // CHECK: %[[VAL_44:.*]] = llvm.xor %[[VAL_40]], %[[VAL_43]] : i32
55
+ // CHECK: %[[VAL_45:.*]] = llvm.xor %[[CST_0]], %[[VAL_44]] : i32
56
+ // CHECK: %[[VAL_46:.*]] = llvm.mul %[[CST_0]], %[[CST_2]] : i32
57
+ // CHECK: %[[VAL_47:.*]] = llvm.xor %[[VAL_45]], %[[VAL_46]] : i32
58
+ // CHECK: %[[VAL_48:.*]] = llvm.xor %[[VAL_47]], %[[CST_0]] : i32
59
+ // CHECK: %[[offset:.*]] = llvm.add %[[VAL_48]], %[[CST_0]] : i32
60
+ // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
61
+ // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<2xf16>
62
+ // CHECK: %[[VAL_67:.*]] = llvm.insertelement {{.*}}, %[[VAL_66]]{{\[}}%[[CST_1]] : i32] : vector<2xf16>
63
+
64
+ // COM: Because the values per thread of the DPAS layout are contiguous, the values are stored in the SLM in a vectorized way.
65
+ // COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16*2) = 32
66
+ // CHECK: llvm.store %[[VAL_67]], %[[VAL_65]] : vector<2xf16>, !llvm.ptr<3>
67
+ // CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<2xf16>, !llvm.ptr<3>
63
68
// CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
64
69
65
70
// COM: Because the values per thread of blocked layout is contiguous. The values are loaded from the SLM in a vectorized way.
66
71
// COM: Total 16 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*4) = 16
67
- // CHECK-COUNT-8 : {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16 >
72
+ // CHECK-COUNT-4 : {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16 >
68
73
69
74
%93 = ttg.convert_layout %cst {allocation.offset = 0 : i32 } : tensor <128 x256 xf16 , #mma > -> tensor <128 x256 xf16 , #blocked >
70
75
%80 = tt.splat %arg0 : !tt.ptr <f16 > -> tensor <128 x1 x!tt.ptr <f16 >, #blocked >
@@ -90,11 +95,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
90
95
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
91
96
// CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
92
97
// CHECK-DAG: %[[CST_4096:.*]] = llvm.mlir.constant(4096 : i32) : i32
98
+ // CHECK-DAG: %[[CST_387:.*]] = llvm.mlir.constant(387 : i32) : i32
93
99
// CHECK-DAG: %[[CST_384:.*]] = llvm.mlir.constant(384 : i32) : i32
94
- // CHECK-DAG: %[[CST_112:.*]] = llvm.mlir.constant(112 : i32) : i32
100
+ // CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
101
+ // CHECK-DAG: %[[CST_48:.*]] = llvm.mlir.constant(48 : i32) : i32
95
102
// CHECK-DAG: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
96
- // CHECK-DAG: %[[CST_8 :.*]] = llvm.mlir.constant(8 : i32) : i32
97
- // CHECK-DAG: %[[CST_5 :.*]] = llvm.mlir.constant(5 : i32) : i32
103
+ // CHECK-DAG: %[[CST_14 :.*]] = llvm.mlir.constant(14 : i32) : i32
104
+ // CHECK-DAG: %[[CST_12 :.*]] = llvm.mlir.constant(12 : i32) : i32
98
105
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
99
106
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
100
107
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -113,43 +120,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
113
120
// CHECK: %[[VAL_26:.*]] = llvm.or %[[CST_0]], %[[VAL_25]] : i32
114
121
// CHECK: %[[VAL_27:.*]] = llvm.shl %[[warpId]], %[[CST_4]] : i32
115
122
// CHECK: %[[VAL_28:.*]] = llvm.or %[[VAL_26]], %[[VAL_27]] : i32
116
- // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_384 ]] : i32
117
- // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_5 ]] : i32
123
+ // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_3 ]] : i32
124
+ // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_14 ]] : i32
118
125
// CHECK: %[[VAL_31:.*]] = llvm.xor %[[CST_0]], %[[VAL_30]] : i32
119
- // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_112 ]] : i32
120
- // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_1 ]] : i32
126
+ // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_387 ]] : i32
127
+ // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_4 ]] : i32
121
128
// CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_31]], %[[VAL_33]] : i32
122
- // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_15 ]] : i32
123
- // CHECK: %[[VAL_36:.*]] = llvm.lshr %[[VAL_35]], %[[CST_0 ]] : i32
129
+ // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_48 ]] : i32
130
+ // CHECK: %[[VAL_36:.*]] = llvm.shl %[[VAL_35]], %[[CST_1 ]] : i32
124
131
// CHECK: %[[VAL_37:.*]] = llvm.xor %[[VAL_34]], %[[VAL_36]] : i32
125
- // CHECK: %[[VAL_38:.*]] = llvm.xor %[[CST_0]], %[[VAL_37]] : i32
126
- // CHECK: %[[VAL_39:.*]] = llvm.and %[[VAL_28]], %[[CST_511]] : i32
127
- // CHECK: %[[VAL_40:.*]] = llvm.shl %[[VAL_39]], %[[CST_3]] : i32
128
- // CHECK: %[[VAL_41:.*]] = llvm.xor %[[CST_0]], %[[VAL_40]] : i32
129
- // CHECK: %[[VAL_42:.*]] = llvm.xor %[[CST_0]], %[[VAL_41]] : i32
130
- // CHECK: %[[VAL_43:.*]] = llvm.xor %[[VAL_38]], %[[CST_0]] : i32
131
- // CHECK: %[[VAL_44:.*]] = llvm.lshr %[[VAL_43]], %[[CST_8]] : i32
132
- // CHECK: %[[VAL_45:.*]] = llvm.shl %[[VAL_44]], %[[CST_3]] : i32
133
- // CHECK: %[[offset:.*]] = llvm.add %[[VAL_45]], %[[VAL_43]] : i32
134
- // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
135
- // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<1xf16>
136
-
137
- // COM: Because the values per thread of DPAS layout is not contiguous. The values are stored in the SLM in a non-vectorized way.
138
- // COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 64*256/(4*8*16) = 32
139
- // CHECK: llvm.store %[[VAL_66]], %[[VAL_65]] : vector<1xf16>, !llvm.ptr<3>
140
- // CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
132
+ // CHECK: %[[VAL_38:.*]] = llvm.and %[[VAL_28]], %[[CST_12]] : i32
133
+ // CHECK: %[[VAL_39:.*]] = llvm.lshr %[[VAL_38]], %[[CST_0]] : i32
134
+ // CHECK: %[[VAL_40:.*]] = llvm.xor %[[VAL_37]], %[[VAL_39]] : i32
135
+ // CHECK: %[[VAL_41:.*]] = llvm.and %[[VAL_28]], %[[CST_64]] : i32
136
+ // CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
137
+ // CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8192]] : i1, i32
138
+ // CHECK: %[[VAL_44:.*]] = llvm.xor %[[VAL_40]], %[[VAL_43]] : i32
139
+ // CHECK: %[[VAL_45:.*]] = llvm.xor %[[CST_0]], %[[VAL_44]] : i32
140
+ // CHECK: %[[VAL_46:.*]] = llvm.mul %[[CST_0]], %[[CST_2]] : i32
141
+ // CHECK: %[[VAL_47:.*]] = llvm.xor %[[VAL_45]], %[[VAL_46]] : i32
142
+ // CHECK: %[[VAL_48:.*]] = llvm.xor %[[VAL_47]], %[[CST_0]] : i32
143
+ // CHECK: %[[offset:.*]] = llvm.add %[[VAL_48]], %[[CST_0]] : i32
144
+ // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
145
+ // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<2xf16>
146
+ // CHECK: %[[VAL_67:.*]] = llvm.insertelement {{.*}}, %[[VAL_66]]{{\[}}%[[CST_1]] : i32] : vector<2xf16>
147
+
148
+ // COM: Because the values per thread of the DPAS layout are contiguous, the values are stored in the SLM in a vectorized way.
149
+ // COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16*2) = 32
150
+ // CHECK: llvm.store %[[VAL_67]], %[[VAL_65]] : vector<2xf16>, !llvm.ptr<3>
151
+ // CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<2xf16>, !llvm.ptr<3>
141
152
// CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
142
153
143
154
// COM: Because the values per thread of blocked layout is contiguous. The values are loaded from the SLM in a vectorized way.
144
- // COM: Total 4 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*8) = 8
145
- // CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
146
-
147
- // COM: The 2nd round of exchanging values.
148
- // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
149
- // CHECK-COUNT-32: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
150
- // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
151
- // CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
152
-
155
+ // COM: Total 16 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*4) = 16
156
+ // CHECK-COUNT-16: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>
153
157
%93 = ttg.convert_layout %cst {allocation.offset = 0 : i32 } : tensor <128 x256 xf16 , #mma > -> tensor <128 x256 xf16 , #blocked >
154
158
%80 = tt.splat %arg0 : !tt.ptr <f16 > -> tensor <128 x1 x!tt.ptr <f16 >, #blocked >
155
159
%83 = tt.broadcast %80 : tensor <128 x1 x!tt.ptr <f16 >, #blocked > -> tensor <128 x256 x!tt.ptr <f16 >, #blocked >
0 commit comments