@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
521
521
LogicalResult
522
522
matchAndRewrite (vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
523
523
ConversionPatternRewriter &rewriter) const override {
524
- auto oldResultType = shuffleOp.getResultVectorType ();
524
+ VectorType oldResultType = shuffleOp.getResultVectorType ();
525
525
Type newResultType = getTypeConverter ()->convertType (oldResultType);
526
526
if (!newResultType)
527
527
return rewriter.notifyMatchFailure (shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
532
532
return cast<IntegerAttr>(attr).getValue ().getZExtValue ();
533
533
});
534
534
535
- auto oldV1Type = shuffleOp.getV1VectorType ();
536
- auto oldV2Type = shuffleOp.getV2VectorType ();
535
+ VectorType oldV1Type = shuffleOp.getV1VectorType ();
536
+ VectorType oldV2Type = shuffleOp.getV2VectorType ();
537
537
538
- // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
539
- if (oldV1Type.getNumElements () > 1 && oldV2Type.getNumElements () > 1 ) {
538
+ // When both operands and the result are SPIR-V vectors, emit a SPIR-V
539
+ // shuffle.
540
+ if (oldV1Type.getNumElements () > 1 && oldV2Type.getNumElements () > 1 &&
541
+ oldResultType.getNumElements () > 1 ) {
540
542
rewriter.replaceOpWithNewOp <spirv::VectorShuffleOp>(
541
543
shuffleOp, newResultType, adaptor.getV1 (), adaptor.getV2 (),
542
544
rewriter.getI32ArrayAttr (mask));
543
545
return success ();
544
546
}
545
547
546
- // When at least one of the operands becomes a scalar after type conversion
547
- // for SPIR-V, extract all the required elements and construct the result
548
- // vector.
548
+ // When at least one of the operands or the result becomes a scalar after
549
+ // type conversion for SPIR-V, extract all the required elements and
550
+ // construct the result vector.
549
551
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc ()](
550
552
Value scalarOrVec, int32_t idx) -> Value {
551
553
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType ()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
569
571
newOperand = getElementAtIdx (vec, elementIdx);
570
572
}
571
573
574
+ // Handle the scalar result corner case.
575
+ if (newOperands.size () == 1 ) {
576
+ rewriter.replaceOp (shuffleOp, newOperands.front ());
577
+ return success ();
578
+ }
579
+
572
580
rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(
573
581
shuffleOp, newResultType, newOperands);
574
-
575
582
return success ();
576
583
}
577
584
};
0 commit comments