
Commit 26b4e7f

hanhanW authored and hhkit committed
[Integrate] Drop llvm/llvm-project@b4c31dc revert. (iree-org#21851)
It carries a cherry-pick fix that gets the operands from the adaptor:
- iree-org/llvm-project@8b88014

Changes:
- Update most lit tests to check `vector.from_elements`.
- Add unrolling patterns to the final conversion.
- Implement n-D `vector::ToElementsOp` lowering, which will be dropped after llvm/llvm-project#156992 lands. It should be added to all the backends, but somehow only the AMDGPU backend needs the pattern. The other backends may address the issue via a specialized tiling config plus dropping vector unit-dim patterns.

Signed-off-by: hanhanW <[email protected]>
Signed-off-by: Ivan Ho <[email protected]>
1 parent ce4f7ca commit 26b4e7f

13 files changed: +91 / -112 lines
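For context on the lit-test churn below: the updated lowering no longer builds the distributed result by chaining vector.insert ops into a zero-initialized constant; it decomposes each computed slice with vector.to_elements and reassembles the whole result with a single vector.from_elements. A hand-written sketch of the two IR shapes, with illustrative names and types that do not come from any specific test:

// Old style: accumulate slices via chained inserts into a zero constant.
%init = arith.constant dense<0.000000e+00> : vector<2x4xf32>
%ins0 = vector.insert %slice0, %init [0] : vector<4xf32> into vector<2x4xf32>
%ins1 = vector.insert %slice1, %ins0 [1] : vector<4xf32> into vector<2x4xf32>

// New style: decompose the slices and rebuild the result in one op.
%e0:4 = vector.to_elements %slice0 : vector<4xf32>
%e1:4 = vector.to_elements %slice1 : vector<4xf32>
%res = vector.from_elements %e0#0, %e0#1, %e0#2, %e0#3,
                            %e1#0, %e1#1, %e1#2, %e1#3 : vector<2x4xf32>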

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir

Lines changed: 35 additions & 31 deletions
@@ -232,7 +232,6 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00>
 // CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x4x1xf32>
 // CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
 // CHECK: %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32
@@ -241,15 +240,18 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x4x1xf32> to vector<16xf32>
 // CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
 // CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x4x1xf32>
-// CHECK: %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
 // CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
 // CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
 // CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
 // CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x4x1xf32> to vector<16xf32>
 // CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
 // CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x4x1xf32>
-// CHECK: %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
-// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
+// CHECK: %[[R0:.+]]:16 = vector.to_elements %[[R0_CAST]] : vector<4x1x4x1xf32>
+// CHECK: %[[R1:.+]]:16 = vector.to_elements %[[R1_CAST]] : vector<4x1x4x1xf32>
+// CHECK: %[[INS:.+]] = vector.from_elements
+// CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3, %[[R0]]#4, %[[R0]]#5, %[[R0]]#6, %[[R0]]#7, %[[R0]]#8, %[[R0]]#9, %[[R0]]#10, %[[R0]]#11, %[[R0]]#12, %[[R0]]#13, %[[R0]]#14, %[[R0]]#15
+// CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3, %[[R1]]#4, %[[R1]]#5, %[[R1]]#6, %[[R1]]#7, %[[R1]]#8, %[[R1]]#9, %[[R1]]#10, %[[R1]]#11, %[[R1]]#12, %[[R1]]#13, %[[R1]]#14, %[[R1]]#15
+// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
 // CHECK: return %[[R]]

 // -----
@@ -403,28 +405,23 @@ builtin.module attributes { transform.with_named_sequence } {
 }
 }

-// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<2x3x4x1x4x1xf32>
-// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
-// CHECK: vector.extract %[[C_SIMT]][0, 0]
-// CHECK: amdgpu.mfma
-// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
-// CHECK: vector.extract %[[C_SIMT]][0, 1]
-// CHECK: amdgpu.mfma
-// CHECK: %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1]
-// CHECK: vector.extract %[[C_SIMT]][0, 2]
-// CHECK: amdgpu.mfma
-// CHECK: %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [0, 2]
-// CHECK: vector.extract %[[C_SIMT]][1, 0]
-// CHECK: amdgpu.mfma
-// CHECK: %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 0]
-// CHECK: vector.extract %[[C_SIMT]][1, 1]
-// CHECK: amdgpu.mfma
-// CHECK: %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [1, 1]
-// CHECK: vector.extract %[[C_SIMT]][1, 2]
-// CHECK: amdgpu.mfma
-// CHECK: %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [1, 2]
-// CHECK: iree_vector_ext.to_simd %[[INS5]]
+// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
+// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
+// CHECK: vector.extract %[[C_SIMT]][0, 0]
+// CHECK: amdgpu.mfma
+// CHECK: vector.extract %[[C_SIMT]][0, 1]
+// CHECK: amdgpu.mfma
+// CHECK: vector.extract %[[C_SIMT]][0, 2]
+// CHECK: amdgpu.mfma
+// CHECK: vector.extract %[[C_SIMT]][1, 0]
+// CHECK: amdgpu.mfma
+// CHECK: vector.extract %[[C_SIMT]][1, 1]
+// CHECK: amdgpu.mfma
+// CHECK: vector.extract %[[C_SIMT]][1, 2]
+// CHECK: amdgpu.mfma
+// CHECK-COUNT-6: vector.to_elements {{.*}} : vector<4x1x4x1xf32>
+// CHECK: %[[INS:.+]] = vector.from_elements
+// CHECK: iree_vector_ext.to_simd %[[INS]]

 // -----

@@ -495,15 +492,17 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mmt
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x4x1xf32>
 // CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
 // CHECK: vector.extract %[[B_SIMT]][0, 0]
 // CHECK: amdgpu.mfma
-// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
 // CHECK: vector.extract %[[B_SIMT]][1, 0]
 // CHECK: amdgpu.mfma
-// CHECK: %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1]
-// CHECK: iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>
+// CHECK: %[[R0:.+]]:16 = vector.to_elements %{{.+}} : vector<4x1x4x1xf32>
+// CHECK: %[[R1:.+]]:16 = vector.to_elements %{{.+}} : vector<4x1x4x1xf32>
+// CHECK: %[[INS:.+]] = vector.from_elements
+// CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3, %[[R0]]#4, %[[R0]]#5, %[[R0]]#6, %[[R0]]#7, %[[R0]]#8, %[[R0]]#9, %[[R0]]#10, %[[R0]]#11, %[[R0]]#12, %[[R0]]#13, %[[R0]]#14, %[[R0]]#15
+// CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3, %[[R1]]#4, %[[R1]]#5, %[[R1]]#6, %[[R1]]#7, %[[R1]]#8, %[[R1]]#9, %[[R1]]#10, %[[R1]]#11, %[[R1]]#12, %[[R1]]#13, %[[R1]]#14, %[[R1]]#15
+// CHECK: iree_vector_ext.to_simd %[[INS]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>

 // -----

@@ -838,6 +837,7 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK: %[[B_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
 // CHECK: %[[MFMA_1:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_1]] + %[[MFMA_0]]
 // CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
+// CHECK: %[[MFMA_1_CAST:.*]] = vector.shape_cast %[[MFMA_1]] : vector<4xf32> to vector<1x1x4x1xf32>
 // CHECK: %[[B_CAST_2:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
 // CHECK: %[[C_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x4x1xf32> to vector<4xf32>
 // CHECK: %[[MFMA_2:.*]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST_2]] + %[[C_CAST_1]]
@@ -846,6 +846,10 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK: %[[MFMA_3:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_3]] + %[[MFMA_2]]
 // CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
 // CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_3]] : vector<4xf32> to vector<1x1x4x1xf32>
-// CHECK: %[[B_OUT:.*]] = vector.insert %[[R_CAST]]
+// CHECK: %[[R0:.+]]:4 = vector.to_elements %[[MFMA_1_CAST]] : vector<1x1x4x1xf32>
+// CHECK: %[[R1:.+]]:4 = vector.to_elements %[[R_CAST]] : vector<1x1x4x1xf32>
+// CHECK: %[[B_OUT:.+]] = vector.from_elements
+// CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3
+// CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3
 // CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x2x1x1x4x1xf32> -> vector<32x32xf32>
 // CHECK: return %[[R_SIMD]]

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir

Lines changed: 3 additions & 6 deletions
@@ -150,14 +150,12 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func @inter_subgroup_reduction
-// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<0.000000e+00> : vector<2xf32>
 // Local reduction
 // CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<2x1x1x1x1x4xf32> to vector<2x1x1xf32>
 // Thread reduction
 // CHECK: %[[THREAD_RED0:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
-// CHECK: %[[THREAD_RED1:.+]] = vector.insert %[[THREAD_RED0]], %[[CST1]] [0] : f32 into vector<2xf32>
 // CHECK: %[[THREAD_RED2:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
-// CHECK: %[[THREAD_RED3:.+]] = vector.insert %[[THREAD_RED2]], %[[THREAD_RED1]] [1] : f32 into vector<2xf32>
+// CHECK: %[[THREAD_RED3:.+]] = vector.from_elements %[[THREAD_RED0]], %[[THREAD_RED2]] : vector<2xf32>
 // CHECK: %[[THREAD_RED4:.+]] = vector.shape_cast %[[THREAD_RED3]] : vector<2xf32> to vector<2x1x1xf32>
 // Subgroup reduction
 // CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
@@ -177,11 +175,10 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
 // CHECK-DAG: %[[DISTR0:.+]] = vector.extract %[[SG_READ0]][0, 0] : f32 from vector<1x1xf32>
 // CHECK-DAG: %[[RED0:.+]] = gpu.subgroup_reduce maximumf %[[DISTR0]] cluster(size = 2, stride = 16) : (f32) -> f32
-// CHECK-DAG: %[[INS0:.+]] = vector.insert %[[RED0]], %[[CST1]] [0] : f32 into vector<2xf32>
 // CHECK-DAG: %[[DISTR1:.+]] = vector.extract %[[SG_READ1]][0, 0] : f32 from vector<1x1xf32>
 // CHECK-DAG: %[[RED1:.+]] = gpu.subgroup_reduce maximumf %[[DISTR1]] cluster(size = 2, stride = 16) : (f32) -> f32
-// CHECK-DAG: %[[INS1:.+]] = vector.insert %[[RED1]], %[[INS0]] [1] : f32 into vector<2xf32>
-// CHECK-DAG: %[[CAST:.+]] = vector.shape_cast %[[INS1]] : vector<2xf32> to vector<2x1x1xf32>
+// CHECK-DAG: %[[INS:.+]] = vector.from_elements %[[RED0]], %[[RED1]] : vector<2xf32>
+// CHECK-DAG: %[[CAST:.+]] = vector.shape_cast %[[INS]] : vector<2xf32> to vector<2x1x1xf32>
 // CHECK-DAG: arith.maximumf %[[CAST]], %[[ACC]] : vector<2x1x1xf32>

 // -----

compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp

Lines changed: 2 additions & 0 deletions
@@ -995,6 +995,7 @@ void ConvertToLLVMPass::runOnOperation() {
       patterns, /*force32BitVectorIndices=*/false);
   vector::populateVectorMaskOpLoweringPatterns(patterns);
   vector::populateVectorShapeCastLoweringPatterns(patterns);
+  vector::populateVectorFromElementsLoweringPatterns(patterns);
   // TODO: doubtful that the "default" does what one want here, it is likely
   // better to use shuffle.
   vector::populateVectorTransposeLoweringPatterns(
@@ -1079,6 +1080,7 @@ void ConvertToLLVMPass::runOnOperation() {
   vector::populateVectorStepLoweringPatterns(patterns);
   populateVectorToLLVMConversionPatterns(typeConverter, patterns,
                                          reassociateFpReductions);
+  vector::populateVectorFromElementsLoweringPatterns(patterns);
   ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
   vector::populateVectorTransferLoweringPatterns(patterns,
                                                  /*maxTransferRank=*/1);

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_pack_unpack_tests.mlir

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ module {

 // CHECK-LABEL: func.func @aligned_generic_pack
 // CHECK: %[[IN_0:.+]] = vector.broadcast %{{.+}} : vector<16xf32> to vector<16x16xf32>
-// CHECK-COUNT-15: %{{.+}} = vector.insert {{.+}} : vector<16xf32> into vector<16x16xf32>
-// CHECK: %[[IN_1:.+]] = vector.insert {{.+}} : vector<16xf32> into vector<16x16xf32>
+// CHECK-COUNT-16: %{{.+}} = vector.to_elements {{.+}} : vector<16xf32>
+// CHECK: %[[IN_1:.+]] = vector.from_elements {{.+}} : vector<16x16xf32>
 // CHECK: %[[T0:.+]] = arith.addf %[[IN_0]], %[[IN_1]] : vector<16x16xf32>
 // CHECK: %[[T1:.+]] = arith.minimumf %[[T0]], %{{.+}} : vector<16x16xf32>
 // CHECK: %[[T2:.+]] = arith.maximumf %[[T1]], %{{.+}} : vector<16x16xf32>

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_split_reduction_tests.mlir

Lines changed: 1 addition & 1 deletion
@@ -274,5 +274,5 @@ func.func @split_reduction_double_reduction_unsupported() attributes {hal.execut
 }

 // CHECK-LABEL: func.func @split_reduction_double_reduction_unsupported()
-// CHECK: vector.insert %{{.+}}, %{{.+}} : i32 into vector<4xi32>
+// CHECK: vector.from_elements %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : vector<4xi32>
 // CHECK-NOT: vector.insert %{{.+}}, %{{.+}} : i32 into vector<1xi32>

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ struct ConvertToNVVMPass final
        patterns, options.vectorContractLowering);
    vector::populateVectorGatherLoweringPatterns(patterns);
    vector::populateVectorMaskOpLoweringPatterns(patterns);
+   vector::populateVectorFromElementsLoweringPatterns(patterns);
    // We currently always use 64 bit indices, thus ensure the bit width of
    // the mask compare is consistent.
    vector::populateVectorMaskMaterializationPatterns(

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp

Lines changed: 22 additions & 0 deletions
@@ -172,6 +172,26 @@ static LogicalResult validateDataTypes(Operation *op,
   return success();
 }

+/// TODO(hanchung): Delete the pattern once it is upstreamed:
+/// https://github.com/llvm/llvm-project/pull/156992
+struct LowerToElementsPattern : public OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getSource().getType();
+    if (vecType.getRank() == 1 || vecType.getNumScalableDims() > 0) {
+      return failure();
+    }
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = rewriter.create<vector::ShapeCastOp>(
+        op.getLoc(), vec1DType, op.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+                                                      shapeCast);
+    return success();
+  }
+};
+
 /// A pass that replaces all occurrences of GPU device operations with their
 /// corresponding ROCDL equivalent.
 ///
@@ -256,6 +276,7 @@ struct ConvertToROCDLPass final
    vector::populateVectorInterleaveToShufflePatterns(patterns);
    vector::populateVectorContractLoweringPatterns(
        patterns, options.vectorContractLowering);
+   vector::populateVectorFromElementsLoweringPatterns(patterns);
    vector::populateVectorGatherLoweringPatterns(patterns);
    vector::populateVectorMaskOpLoweringPatterns(patterns);
    // We currently always use 64 bit indices, thus ensure the bit width of
@@ -269,6 +290,7 @@ struct ConvertToROCDLPass final
        patterns, options.vectorTransposeLowering);
    vector::populateVectorTransferLoweringPatterns(patterns);
    arith::populateExpandBFloat16Patterns(patterns);
+   patterns.insert<LowerToElementsPattern>(&getContext());
    if (failed(applyPatternsGreedily(m, std::move(patterns)))) {
      return signalPassFailure();
    }
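The LowerToElementsPattern added in this file only flattens the source vector; the remaining rank-1 vector.to_elements is then handled by the existing lowerings and the LLVM conversion. A rough before/after sketch of what the rewrite produces, with an illustrative type not taken from any test:

// Before: n-D to_elements, not directly handled on the ROCDL path.
%e:4 = vector.to_elements %v : vector<2x2xf32>

// After the pattern: flatten with a shape_cast, then take elements from the 1-D vector.
%flat = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
%e:4 = vector.to_elements %flat : vector<4xf32>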

compiler/src/iree/compiler/Codegen/SPIRV/test/break_down_large_vector.mlir

Lines changed: 6 additions & 17 deletions
@@ -3,7 +3,7 @@
 // CHECK-LABEL: func @extract_strided_slice_8_elements
 func.func @extract_strided_slice_8_elements(%input: vector<8xf16>) -> vector<4xf16> {
 // CHECK-COUNT-4: vector.extract
-// CHECK-COUNT-4: vector.insert
+// CHECK: vector.from_elements
 %0 = vector.extract_strided_slice %input {offsets = [1], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 return %0: vector<4xf16>
 }
@@ -22,9 +22,8 @@ func.func @extract_strided_slice_4_elements(%input: vector<4xf16>) -> vector<2xf
 // CHECK-LABEL: func @bitcast_16_elements
 func.func @bitcast_16_elements(%input: vector<16xi8>) -> vector<4xi32> {
 // CHECK-DAG: %[[CST_I32:.*]] = arith.constant dense<0> : vector<4xi32>
-// CHECK-DAG: arith.constant dense<0> : vector<4xi8>
 // CHECK-COUNT-4: vector.extract
-// CHECK-COUNT-4: vector.insert
+// CHECK: vector.from_elements
 // CHECK: vector.bitcast %{{.*}} : vector<4xi8> to vector<1xi32>
 // CHECK: vector.insert_strided_slice {{.*}}, %[[CST_I32]]
 // CHECK-COUNT-3: vector.bitcast
@@ -41,28 +40,22 @@ func.func @bitcast_extract_extend_0(%input: vector<1xi32>) -> vector<4xi32> {
 return %extend : vector<4xi32>
 }

-
 // CHECK-LABEL: func @bitcast_extract_extend_0
 // CHECK-SAME: (%[[INPUT:.+]]: vector<1xi32>)
-// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<4xi32>
 // CHECK-DAG: %[[MASK:.+]] = arith.constant 15 : i32
 // CHECK-DAG: %[[OFF1:.+]] = arith.constant 4 : i32
 // CHECK-DAG: %[[OFF2:.+]] = arith.constant 8 : i32
 // CHECK-DAG: %[[OFF3:.+]] = arith.constant 12 : i32
 // CHECK: %[[BASE:.+]] = vector.extract %[[INPUT]][0] : i32 from vector<1xi32>
 // CHECK: %[[AND0:.+]] = arith.andi %[[BASE]], %[[MASK]] : i32
-// CHECK: %[[INS0:.+]] = vector.insert %[[AND0]], %[[ZERO]] [0]
 // CHECK: %[[SHR1:.+]] = arith.shrui %[[BASE]], %[[OFF1]] : i32
 // CHECK: %[[AND1:.+]] = arith.andi %[[SHR1]], %[[MASK]] : i32
-// CHECK: %[[INS1:.+]] = vector.insert %[[AND1]], %[[INS0]] [1]
 // CHECK: %[[SHR2:.+]] = arith.shrui %[[BASE]], %[[OFF2]] : i32
 // CHECK: %[[AND2:.+]] = arith.andi %[[SHR2]], %[[MASK]] : i32
-// CHECK: %[[INS2:.+]] = vector.insert %[[AND2]], %[[INS1]] [2]
 // CHECK: %[[SHR3:.+]] = arith.shrui %[[BASE]], %[[OFF3]] : i32
 // CHECK: %[[AND3:.+]] = arith.andi %[[SHR3]], %[[MASK]] : i32
-// CHECK: %[[INS3:.+]] = vector.insert %[[AND3]], %[[INS2]] [3]
-// CHECK: return %[[INS3]] : vector<4xi32>
-
+// CHECK: %[[RES:.+]] = vector.from_elements %[[AND0]], %[[AND1]], %[[AND2]], %[[AND3]] : vector<4xi32>
+// CHECK: return %[[RES]] : vector<4xi32>

 // -----

@@ -75,7 +68,6 @@ func.func @bitcast_extract_extend_1(%input: vector<4xi32>) -> vector<4xi32> {

 // CHECK-LABEL: func.func @bitcast_extract_extend_1
 // CHECK-SAME: (%[[INPUT:.+]]: vector<4xi32>)
-// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<4xi32>
 // CHECK-DAG: %[[MASK:.+]] = arith.constant 15 : i32
 // CHECK-DAG: %[[OFF0:.+]] = arith.constant 16 : i32
 // CHECK-DAG: %[[OFF1:.+]] = arith.constant 20 : i32
@@ -84,14 +76,11 @@ func.func @bitcast_extract_extend_1(%input: vector<4xi32>) -> vector<4xi32> {
 // CHECK: %[[BASE:.+]] = vector.extract %[[INPUT]][2] : i32 from vector<4xi32>
 // CHECK: %[[SHR0:.+]] = arith.shrui %[[BASE]], %[[OFF0]] : i32
 // CHECK: %[[AND0:.+]] = arith.andi %[[SHR0]], %[[MASK]] : i32
-// CHECK: %[[INS0:.+]] = vector.insert %[[AND0]], %[[ZERO]] [0]
 // CHECK: %[[SHR1:.+]] = arith.shrui %[[BASE]], %[[OFF1]] : i32
 // CHECK: %[[AND1:.+]] = arith.andi %[[SHR1]], %[[MASK]] : i32
-// CHECK: %[[INS1:.+]] = vector.insert %[[AND1]], %[[INS0]] [1]
 // CHECK: %[[SHR2:.+]] = arith.shrui %[[BASE]], %[[OFF2]] : i32
 // CHECK: %[[AND2:.+]] = arith.andi %[[SHR2]], %[[MASK]] : i32
-// CHECK: %[[INS2:.+]] = vector.insert %[[AND2]], %[[INS1]] [2]
 // CHECK: %[[SHR3:.+]] = arith.shrui %[[BASE]], %[[OFF3]] : i32
 // CHECK: %[[AND3:.+]] = arith.andi %[[SHR3]], %[[MASK]] : i32
-// CHECK: %[[INS3:.+]] = vector.insert %[[AND3]], %[[INS2]] [3]
-// CHECK: return %[[INS3]] : vector<4xi32>
+// CHECK: %[[RES:.+]] = vector.from_elements %[[AND0]], %[[AND1]], %[[AND2]], %[[AND3]] : vector<4xi32>
+// CHECK: return %[[RES]] : vector<4xi32>
