Skip to content

Commit a72eed7

Browse files
authored
[mlir][spirv] Handle scalar shuffles in vector to spirv conversion (llvm#98809)
These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit `spirv.VectorShuffle` and need to lower this to `spirv.CompositeExtract`.
1 parent 3ccda93 commit a72eed7

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
521521
LogicalResult
522522
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
523523
ConversionPatternRewriter &rewriter) const override {
524-
auto oldResultType = shuffleOp.getResultVectorType();
524+
VectorType oldResultType = shuffleOp.getResultVectorType();
525525
Type newResultType = getTypeConverter()->convertType(oldResultType);
526526
if (!newResultType)
527527
return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
532532
return cast<IntegerAttr>(attr).getValue().getZExtValue();
533533
});
534534

535-
auto oldV1Type = shuffleOp.getV1VectorType();
536-
auto oldV2Type = shuffleOp.getV2VectorType();
535+
VectorType oldV1Type = shuffleOp.getV1VectorType();
536+
VectorType oldV2Type = shuffleOp.getV2VectorType();
537537

538-
// When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
539-
if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
538+
// When both operands and the result are SPIR-V vectors, emit a SPIR-V
539+
// shuffle.
540+
if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
541+
oldResultType.getNumElements() > 1) {
540542
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
541543
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
542544
rewriter.getI32ArrayAttr(mask));
543545
return success();
544546
}
545547

546-
// When at least one of the operands becomes a scalar after type conversion
547-
// for SPIR-V, extract all the required elements and construct the result
548-
// vector.
548+
// When at least one of the operands or the result becomes a scalar after
549+
// type conversion for SPIR-V, extract all the required elements and
550+
// construct the result vector.
549551
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
550552
Value scalarOrVec, int32_t idx) -> Value {
551553
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
569571
newOperand = getElementAtIdx(vec, elementIdx);
570572
}
571573

574+
// Handle the scalar result corner case.
575+
if (newOperands.size() == 1) {
576+
rewriter.replaceOp(shuffleOp, newOperands.front());
577+
return success();
578+
}
579+
572580
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
573581
shuffleOp, newResultType, newOperands);
574-
575582
return success();
576583
}
577584
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
483483

484484
// -----
485485

486+
// CHECK-LABEL: func @shuffle
487+
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
488+
// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
489+
// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
490+
// CHECK: return %[[RES]] : vector<1xi32>
491+
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
492+
%shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
493+
return %shuffle : vector<1xi32>
494+
}
495+
496+
// -----
497+
498+
// CHECK-LABEL: func @shuffle
499+
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
500+
// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
501+
// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
502+
// CHECK: return %[[RES]] : vector<1xi32>
503+
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
504+
%shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
505+
return %shuffle : vector<1xi32>
506+
}
507+
508+
// -----
509+
486510
// CHECK-LABEL: func @interleave
487511
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
488512
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

0 commit comments

Comments
 (0)