Skip to content

Commit b3d648d

Browse files
wdziurdzwhitneywhtsang
authored andcommitted
Adjust LIT tests for the swizzling path
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 0c4fd9c commit b3d648d

File tree

2 files changed

+75
-71
lines changed

2 files changed

+75
-71
lines changed

test/Conversion/intel/dpas_to_block_layout_convert.mlir

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
1111
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #mma>
1212

1313
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
14-
// CHECK-DAG: %[[CST_16384:.*]] = llvm.mlir.constant(16384 : i32) : i32
1514
// CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
15+
// CHECK-DAG: %[[CST_387:.*]] = llvm.mlir.constant(387 : i32) : i32
1616
// CHECK-DAG: %[[CST_384:.*]] = llvm.mlir.constant(384 : i32) : i32
17-
// CHECK-DAG: %[[CST_112:.*]] = llvm.mlir.constant(112 : i32) : i32
17+
// CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
18+
// CHECK-DAG: %[[CST_48:.*]] = llvm.mlir.constant(48 : i32) : i32
1819
// CHECK-DAG: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
19-
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
20-
// CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
20+
// CHECK-DAG: %[[CST_14:.*]] = llvm.mlir.constant(14 : i32) : i32
21+
// CHECK-DAG: %[[CST_12:.*]] = llvm.mlir.constant(12 : i32) : i32
2122
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
2223
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
2324
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -35,36 +36,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
3536
// CHECK: %[[VAL_26:.*]] = llvm.or %[[CST_0]], %[[VAL_25]] : i32
3637
// CHECK: %[[VAL_27:.*]] = llvm.shl %[[warpId]], %[[CST_4]] : i32
3738
// CHECK: %[[VAL_28:.*]] = llvm.or %[[VAL_26]], %[[VAL_27]] : i32
38-
// CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_384]] : i32
39-
// CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_6]] : i32
39+
// CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_3]] : i32
40+
// CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_14]] : i32
4041
// CHECK: %[[VAL_31:.*]] = llvm.xor %[[CST_0]], %[[VAL_30]] : i32
41-
// CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_112]] : i32
42-
// CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_1]] : i32
42+
// CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_387]] : i32
43+
// CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_4]] : i32
4344
// CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_31]], %[[VAL_33]] : i32
44-
// CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_15]] : i32
45-
// CHECK: %[[VAL_36:.*]] = llvm.lshr %[[VAL_35]], %[[CST_0]] : i32
45+
// CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_48]] : i32
46+
// CHECK: %[[VAL_36:.*]] = llvm.shl %[[VAL_35]], %[[CST_1]] : i32
4647
// CHECK: %[[VAL_37:.*]] = llvm.xor %[[VAL_34]], %[[VAL_36]] : i32
47-
// CHECK: %[[VAL_38:.*]] = llvm.xor %[[CST_0]], %[[VAL_37]] : i32
48-
// CHECK: %[[VAL_39:.*]] = llvm.and %[[VAL_28]], %[[CST_511]] : i32
49-
// CHECK: %[[VAL_40:.*]] = llvm.shl %[[VAL_39]], %[[CST_3]] : i32
50-
// CHECK: %[[VAL_41:.*]] = llvm.xor %[[CST_0]], %[[VAL_40]] : i32
51-
// CHECK: %[[VAL_42:.*]] = llvm.xor %[[CST_0]], %[[VAL_41]] : i32
52-
// CHECK: %[[VAL_43:.*]] = llvm.xor %[[VAL_38]], %[[CST_0]] : i32
53-
// CHECK: %[[VAL_44:.*]] = llvm.lshr %[[VAL_43]], %[[CST_8]] : i32
54-
// CHECK: %[[VAL_45:.*]] = llvm.shl %[[VAL_44]], %[[CST_3]] : i32
55-
// CHECK: %[[offset:.*]] = llvm.add %[[VAL_45]], %[[VAL_43]] : i32
56-
// CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
57-
// CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<1xf16>
58-
59-
// COM: Because the values per thread of DPAS layout is not contiguous. The values are stored in the SLM in a non-vectorized way.
60-
// COM: Total 64 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16) = 64
61-
// CHECK: llvm.store %[[VAL_66]], %[[VAL_65]] : vector<1xf16>, !llvm.ptr<3>
62-
// CHECK-COUNT-63: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
48+
// CHECK: %[[VAL_38:.*]] = llvm.and %[[VAL_28]], %[[CST_12]] : i32
49+
// CHECK: %[[VAL_39:.*]] = llvm.lshr %[[VAL_38]], %[[CST_0]] : i32
50+
// CHECK: %[[VAL_40:.*]] = llvm.xor %[[VAL_37]], %[[VAL_39]] : i32
51+
// CHECK: %[[VAL_41:.*]] = llvm.and %[[VAL_28]], %[[CST_64]] : i32
52+
// CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
53+
// CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8192]] : i1, i32
54+
// CHECK: %[[VAL_44:.*]] = llvm.xor %[[VAL_40]], %[[VAL_43]] : i32
55+
// CHECK: %[[VAL_45:.*]] = llvm.xor %[[CST_0]], %[[VAL_44]] : i32
56+
// CHECK: %[[VAL_46:.*]] = llvm.mul %[[CST_0]], %[[CST_2]] : i32
57+
// CHECK: %[[VAL_47:.*]] = llvm.xor %[[VAL_45]], %[[VAL_46]] : i32
58+
// CHECK: %[[VAL_48:.*]] = llvm.xor %[[VAL_47]], %[[CST_0]] : i32
59+
// CHECK: %[[offset:.*]] = llvm.add %[[VAL_48]], %[[CST_0]] : i32
60+
// CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
61+
// CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<2xf16>
62+
// CHECK: %[[VAL_67:.*]] = llvm.insertelement {{.*}}, %[[VAL_66]]{{\[}}%[[CST_1]] : i32] : vector<2xf16>
63+
64+
// COM: Because the values per thread of DPAS layout is contiguous. The values are stored in the SLM in vectorized way.
65+
// COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16*2) = 32
66+
// CHECK: llvm.store %[[VAL_67]], %[[VAL_65]] : vector<2xf16>, !llvm.ptr<3>
67+
// CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<2xf16>, !llvm.ptr<3>
6368
// CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
6469

6570
// COM: Because the values per thread of blocked layout is contiguous. The values are loaded from the SLM in a vectorized way.
6671
// COM: Total 8 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*8) = 8
67-
// CHECK-COUNT-8: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
72+
// CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>
6873

6974
%93 = ttg.convert_layout %cst {allocation.offset = 0 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
7075
%80 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
@@ -90,11 +95,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
9095
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
9196
// CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
9297
// CHECK-DAG: %[[CST_4096:.*]] = llvm.mlir.constant(4096 : i32) : i32
98+
// CHECK-DAG: %[[CST_387:.*]] = llvm.mlir.constant(387 : i32) : i32
9399
// CHECK-DAG: %[[CST_384:.*]] = llvm.mlir.constant(384 : i32) : i32
94-
// CHECK-DAG: %[[CST_112:.*]] = llvm.mlir.constant(112 : i32) : i32
100+
// CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
101+
// CHECK-DAG: %[[CST_48:.*]] = llvm.mlir.constant(48 : i32) : i32
95102
// CHECK-DAG: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
96-
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
97-
// CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32
103+
// CHECK-DAG: %[[CST_14:.*]] = llvm.mlir.constant(14 : i32) : i32
104+
// CHECK-DAG: %[[CST_12:.*]] = llvm.mlir.constant(12 : i32) : i32
98105
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
99106
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
100107
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -113,43 +120,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
113120
// CHECK: %[[VAL_26:.*]] = llvm.or %[[CST_0]], %[[VAL_25]] : i32
114121
// CHECK: %[[VAL_27:.*]] = llvm.shl %[[warpId]], %[[CST_4]] : i32
115122
// CHECK: %[[VAL_28:.*]] = llvm.or %[[VAL_26]], %[[VAL_27]] : i32
116-
// CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_384]] : i32
117-
// CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_5]] : i32
123+
// CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_3]] : i32
124+
// CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_14]] : i32
118125
// CHECK: %[[VAL_31:.*]] = llvm.xor %[[CST_0]], %[[VAL_30]] : i32
119-
// CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_112]] : i32
120-
// CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_1]] : i32
126+
// CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_387]] : i32
127+
// CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_4]] : i32
121128
// CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_31]], %[[VAL_33]] : i32
122-
// CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_15]] : i32
123-
// CHECK: %[[VAL_36:.*]] = llvm.lshr %[[VAL_35]], %[[CST_0]] : i32
129+
// CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_48]] : i32
130+
// CHECK: %[[VAL_36:.*]] = llvm.shl %[[VAL_35]], %[[CST_1]] : i32
124131
// CHECK: %[[VAL_37:.*]] = llvm.xor %[[VAL_34]], %[[VAL_36]] : i32
125-
// CHECK: %[[VAL_38:.*]] = llvm.xor %[[CST_0]], %[[VAL_37]] : i32
126-
// CHECK: %[[VAL_39:.*]] = llvm.and %[[VAL_28]], %[[CST_511]] : i32
127-
// CHECK: %[[VAL_40:.*]] = llvm.shl %[[VAL_39]], %[[CST_3]] : i32
128-
// CHECK: %[[VAL_41:.*]] = llvm.xor %[[CST_0]], %[[VAL_40]] : i32
129-
// CHECK: %[[VAL_42:.*]] = llvm.xor %[[CST_0]], %[[VAL_41]] : i32
130-
// CHECK: %[[VAL_43:.*]] = llvm.xor %[[VAL_38]], %[[CST_0]] : i32
131-
// CHECK: %[[VAL_44:.*]] = llvm.lshr %[[VAL_43]], %[[CST_8]] : i32
132-
// CHECK: %[[VAL_45:.*]] = llvm.shl %[[VAL_44]], %[[CST_3]] : i32
133-
// CHECK: %[[offset:.*]] = llvm.add %[[VAL_45]], %[[VAL_43]] : i32
134-
// CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
135-
// CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<1xf16>
136-
137-
// COM: Because the values per thread of DPAS layout is not contiguous. The values are stored in the SLM in a non-vectorized way.
138-
// COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 64*256/(4*8*16) = 32
139-
// CHECK: llvm.store %[[VAL_66]], %[[VAL_65]] : vector<1xf16>, !llvm.ptr<3>
140-
// CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
132+
// CHECK: %[[VAL_38:.*]] = llvm.and %[[VAL_28]], %[[CST_12]] : i32
133+
// CHECK: %[[VAL_39:.*]] = llvm.lshr %[[VAL_38]], %[[CST_0]] : i32
134+
// CHECK: %[[VAL_40:.*]] = llvm.xor %[[VAL_37]], %[[VAL_39]] : i32
135+
// CHECK: %[[VAL_41:.*]] = llvm.and %[[VAL_28]], %[[CST_64]] : i32
136+
// CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
137+
// CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8192]] : i1, i32
138+
// CHECK: %[[VAL_44:.*]] = llvm.xor %[[VAL_40]], %[[VAL_43]] : i32
139+
// CHECK: %[[VAL_45:.*]] = llvm.xor %[[CST_0]], %[[VAL_44]] : i32
140+
// CHECK: %[[VAL_46:.*]] = llvm.mul %[[CST_0]], %[[CST_2]] : i32
141+
// CHECK: %[[VAL_47:.*]] = llvm.xor %[[VAL_45]], %[[VAL_46]] : i32
142+
// CHECK: %[[VAL_48:.*]] = llvm.xor %[[VAL_47]], %[[CST_0]] : i32
143+
// CHECK: %[[offset:.*]] = llvm.add %[[VAL_48]], %[[CST_0]] : i32
144+
// CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
145+
// CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<2xf16>
146+
// CHECK: %[[VAL_67:.*]] = llvm.insertelement {{.*}}, %[[VAL_66]]{{\[}}%[[CST_1]] : i32] : vector<2xf16>
147+
148+
// COM: Because the values per thread of DPAS layout is contiguous. The values are stored in the SLM in vectorized way.
149+
// COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16*2) = 32
150+
// CHECK: llvm.store %[[VAL_67]], %[[VAL_65]] : vector<2xf16>, !llvm.ptr<3>
151+
// CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<2xf16>, !llvm.ptr<3>
141152
// CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
142153

143154
// COM: Because the values per thread of blocked layout is contiguous. The values are loaded from the SLM in a vectorized way.
144-
// COM: Total 4 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*8) = 8
145-
// CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
146-
147-
// COM: The 2nd round of exchanging values.
148-
// CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
149-
// CHECK-COUNT-32: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
150-
// CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
151-
// CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
152-
155+
// COM: Total 16 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*4) = 16
156+
// CHECK-COUNT-16: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>
153157
%93 = ttg.convert_layout %cst {allocation.offset = 0 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
154158
%80 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
155159
%83 = tt.broadcast %80 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x256x!tt.ptr<f16>, #blocked>

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -816,12 +816,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
816816
// CHECK-LABEL: convert_layout_dpas_block
817817
tt.func @convert_layout_dpas_blocked(%arg0: tensor<32x16xf32, #dpas>) {
818818
// CHECK: llvm.store
819-
// CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
819+
// CHECK-SAME: vector<2xf32>, !llvm.ptr<3>
820820
// CHECK: llvm.store
821-
// CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
821+
// CHECK-SAME: vector<2xf32>, !llvm.ptr<3>
822822
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
823823
// CHECK: llvm.load
824-
// CHECK-SAME: !llvm.ptr<3> -> vector<4xf32>
824+
// CHECK-SAME: !llvm.ptr<3> -> vector<2xf32>
825825
%0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #dpas> -> tensor<32x16xf32, #blocked0>
826826
tt.return
827827
}
@@ -836,13 +836,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
836836
// CHECK-LABEL: convert_layout_dpas_block
837837
tt.func @convert_layout_dpas_blocked(%arg0: tensor<32x64xf32, #dpas>) {
838838
// CHECK: llvm.store
839-
// CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
839+
// CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
840840
// CHECK: llvm.store
841-
// CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
841+
// CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
842842
// CHECK: llvm.store
843-
// CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
843+
// CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
844844
// CHECK: llvm.store
845-
// CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
845+
// CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
846846
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
847847
// CHECK: llvm.load
848848
// CHECK-SAME: !llvm.ptr<3> -> vector<4xf32>
@@ -858,9 +858,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
858858
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
859859
// CHECK-LABEL: convert_layout_dpas_transpose
860860
tt.func @convert_layout_dpas_transpose(%arg0: tensor<128x256xf8E5M2, #dpas>) {
861-
// CHECK-COUNT-128: llvm.store %{{.*}} : vector<1xi8>, !llvm.ptr<3>
861+
// CHECK-COUNT-16: llvm.store %{{.*}} : vector<16xi8>, !llvm.ptr<3>
862862
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
863-
// CHECK-COUNT-80: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xi8>
863+
// CHECK-COUNT-2: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
864864
%0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #dpas> -> tensor<128x256xf8E5M2, #blocked>
865865
tt.return
866866
}
@@ -902,7 +902,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
902902
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
903903
// CHECK-LABEL: convert_blocked1d_to_slice1
904904
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
905-
// CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<3>
905+
// CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xi32>
906906
%cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
907907
tt.return
908908
}

0 commit comments

Comments
 (0)