@@ -650,30 +650,33 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                         int64_t rank) {
   // No need to expand if we are already at the desired rank
-  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
-  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
-  int64_t numExtraDims = rank - shapedType.getRank();
+  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
+  assert(tensorType && "expected a ranked tensor type");
+  int64_t tensorRank = tensorType.getRank();
+  int64_t numExtraDims = rank - tensorRank;
   assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
   if (!numExtraDims)
     return tensor;
 
   // Compute reassociation indices
-  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
-      shapedType.getRank());
+  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
   int64_t index = 0;
-  for (index = 0; index <= numExtraDims; index++)
-    reassociationIndices[0].push_back(index);
-  for (size_t position = 1; position < reassociationIndices.size(); position++)
-    reassociationIndices[position].push_back(index++);
+  if (tensorRank != 0) {
+    for (index = 0; index <= numExtraDims; index++)
+      reassociationIndices[0].push_back(index);
+    for (size_t position = 1; position < reassociationIndices.size();
+         position++)
+      reassociationIndices[position].push_back(index++);
+  }
 
   // Compute result type
   SmallVector<int64_t> resultShape;
   for (index = 0; index < numExtraDims; index++)
     resultShape.push_back(1);
-  for (auto size : shapedType.getShape())
+  for (auto size : tensorType.getShape())
     resultShape.push_back(size);
   auto resultType =
-      RankedTensorType::get(resultShape, shapedType.getElementType());
+      RankedTensorType::get(resultShape, tensorType.getElementType());
 
   // Emit 'tensor.expand_shape' op
   return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
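
For illustration only (not part of the commit): a minimal standalone sketch of the reassociation grouping this change implements, using plain std::vector in place of the MLIR SmallVector/ReassociationIndices types; the helper name computeReassociation is hypothetical. Expanding a rank-2 tensor to rank 4 folds the new leading unit dims into the first group, giving {{0, 1, 2}, {3}}. Since reassociationIndices holds one entry per source dimension, a rank-0 input previously wrote into element 0 of an empty vector; with the guard it now produces an empty reassociation list.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone re-implementation of the grouping logic above:
// the first 'numExtraDims + 1' expanded dims collapse into source dim 0,
// every remaining source dim maps to exactly one expanded dim, and a
// rank-0 source produces no groups at all (the new guard in this patch).
static std::vector<std::vector<int64_t>>
computeReassociation(int64_t sourceRank, int64_t targetRank) {
  int64_t numExtraDims = targetRank - sourceRank;
  std::vector<std::vector<int64_t>> groups(sourceRank);
  int64_t index = 0;
  if (sourceRank != 0) {
    for (index = 0; index <= numExtraDims; index++)
      groups[0].push_back(index);
    for (size_t position = 1; position < groups.size(); position++)
      groups[position].push_back(index++);
  }
  return groups;
}

static void print(const std::vector<std::vector<int64_t>> &groups) {
  for (const auto &group : groups) {
    std::cout << "{ ";
    for (int64_t dim : group)
      std::cout << dim << ' ';
    std::cout << "} ";
  }
  std::cout << '\n';
}

int main() {
  print(computeReassociation(2, 4)); // { 0 1 2 } { 3 }
  print(computeReassociation(0, 3)); // (empty: rank-0 source has no groups)
}

For the rank-0 case the result shape computed above is all ones (one unit dim per extra rank), so the emitted tensor.expand_shape carries an empty reassociation and a static all-ones result type.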