@@ -600,30 +600,33 @@ static Value createLinalgBodyCalculationForElementwiseOp(
600600static Value expandRank (PatternRewriter &rewriter, Location loc, Value tensor,
601601 int64_t rank) {
602602 // No need to expand if we are already at the desired rank
603- auto shapedType = dyn_cast<ShapedType>(tensor.getType ());
604- assert (shapedType && shapedType.hasRank () && " expected a ranked shaped type" );
605- int64_t numExtraDims = rank - shapedType.getRank ();
603+ auto tensorType = dyn_cast<RankedTensorType>(tensor.getType ());
604+ assert (tensorType && " expected a ranked tensor type" );
605+ int64_t tensorRank = tensorType.getRank ();
606+ int64_t numExtraDims = rank - tensorRank;
606607 assert (numExtraDims >= 0 && " cannot expand tensor to a lower rank" );
607608 if (!numExtraDims)
608609 return tensor;
609610
610611 // Compute reassociation indices
611- SmallVector<SmallVector<int64_t , 2 >> reassociationIndices (
612- shapedType.getRank ());
612+ SmallVector<ReassociationIndices> reassociationIndices (tensorRank);
613613 int64_t index = 0 ;
614- for (index = 0 ; index <= numExtraDims; index++)
615- reassociationIndices[0 ].push_back (index);
616- for (size_t position = 1 ; position < reassociationIndices.size (); position++)
617- reassociationIndices[position].push_back (index++);
614+ if (tensorRank != 0 ) {
615+ for (index = 0 ; index <= numExtraDims; index++)
616+ reassociationIndices[0 ].push_back (index);
617+ for (size_t position = 1 ; position < reassociationIndices.size ();
618+ position++)
619+ reassociationIndices[position].push_back (index++);
620+ }
618621
619622 // Compute result type
620623 SmallVector<int64_t > resultShape;
621624 for (index = 0 ; index < numExtraDims; index++)
622625 resultShape.push_back (1 );
623- for (auto size : shapedType .getShape ())
626+ for (auto size : tensorType .getShape ())
624627 resultShape.push_back (size);
625628 auto resultType =
626- RankedTensorType::get (resultShape, shapedType .getElementType ());
629+ RankedTensorType::get (resultShape, tensorType .getElementType ());
627630
628631 // Emit 'tensor.expand_shape' op
629632 return rewriter.create <tensor::ExpandShapeOp>(loc, resultType, tensor,
0 commit comments