@@ -718,9 +718,123 @@ struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
718
718
};
719
719
} // namespace
720
720
721
+ static Value getCollapsedInitTensor (OpBuilder &builder,
722
+ TensorReshapeOp reshapeOp) {
723
+ Location loc = reshapeOp.getLoc ();
724
+ SmallVector<Value, 4 > dynamicShapes;
725
+ SmallVector<int64_t , 4 > staticShapes;
726
+ auto reassociation = reshapeOp.getReassociationMaps ();
727
+ Value src = reshapeOp.src ();
728
+ RankedTensorType srcType = reshapeOp.getSrcType ();
729
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
730
+ for (auto map : reassociation) {
731
+ Value linearizedDynamicDim = nullptr ;
732
+ int64_t linearizedStaticDim = 1 ;
733
+ for (unsigned i : llvm::map_range (map.getResults (), [](AffineExpr e) {
734
+ return e.cast <AffineDimExpr>().getPosition ();
735
+ })) {
736
+ if (ShapedType::isDynamic (srcShape[i])) {
737
+ Value shapeVal = builder.create <DimOp>(loc, src, i);
738
+ if (linearizedDynamicDim) {
739
+ linearizedDynamicDim =
740
+ builder.create <MulIOp>(loc, linearizedDynamicDim, shapeVal);
741
+ } else {
742
+ linearizedDynamicDim = shapeVal;
743
+ }
744
+ } else {
745
+ linearizedStaticDim *= srcShape[i];
746
+ }
747
+ }
748
+ if (linearizedDynamicDim) {
749
+ if (linearizedStaticDim != 1 ) {
750
+ linearizedDynamicDim = builder.create <MulIOp>(
751
+ loc, linearizedDynamicDim,
752
+ builder.create <ConstantIndexOp>(loc, linearizedStaticDim));
753
+ }
754
+ dynamicShapes.push_back (linearizedDynamicDim);
755
+ staticShapes.push_back (ShapedType::kDynamicSize );
756
+ } else {
757
+ staticShapes.push_back (linearizedStaticDim);
758
+ }
759
+ }
760
+ return builder.create <InitTensorOp>(loc, dynamicShapes, staticShapes,
761
+ srcType.getElementType ());
762
+ }
763
+
764
+ static Value getExpandedInitTensor (OpBuilder &builder,
765
+ TensorReshapeOp reshapeOp) {
766
+ SmallVector<Value, 4 > dynamicShapes;
767
+ SmallVector<int64_t , 4 > staticShapes;
768
+ auto reassociation = reshapeOp.getReassociationMaps ();
769
+ Value src = reshapeOp.src ();
770
+ RankedTensorType srcType = reshapeOp.getSrcType ();
771
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
772
+ ArrayRef<int64_t > dstShape = reshapeOp.getResultType ().getShape ();
773
+ Location loc = reshapeOp.getLoc ();
774
+ for (auto map : enumerate(reassociation)) {
775
+ int64_t linearizedStaticDim = 1 ;
776
+ bool hasDynamic = false ;
777
+ for (unsigned i :
778
+ llvm::map_range (map.value ().getResults (), [](AffineExpr e) {
779
+ return e.cast <AffineDimExpr>().getPosition ();
780
+ })) {
781
+ if (ShapedType::isDynamic (dstShape[i])) {
782
+ // Only one of the dimensions of the expanded shape should be dynamic.
783
+ if (hasDynamic)
784
+ return nullptr ;
785
+ hasDynamic = true ;
786
+ staticShapes.push_back (ShapedType::kDynamicSize );
787
+ continue ;
788
+ }
789
+ staticShapes.push_back (dstShape[i]);
790
+ linearizedStaticDim *= dstShape[i];
791
+ }
792
+ if (hasDynamic) {
793
+ // If the expanded dimensions has a dynamic shape, the src shape must be
794
+ // dynamic as well.
795
+ if (!ShapedType::isDynamic (srcShape[map.index ()]))
796
+ return nullptr ;
797
+ Value dynamicDim = builder.create <DimOp>(loc, src, map.index ());
798
+ if (linearizedStaticDim != 1 ) {
799
+ dynamicDim = builder.create <UnsignedDivIOp>(
800
+ loc, dynamicDim,
801
+ builder.create <ConstantIndexOp>(loc, linearizedStaticDim));
802
+ }
803
+ dynamicShapes.push_back (dynamicDim);
804
+ }
805
+ }
806
+ return builder.create <InitTensorOp>(loc, dynamicShapes, staticShapes,
807
+ srcType.getElementType ());
808
+ }
809
+
810
+ namespace {
811
+ struct FoldWithTensorReshapeOp : public OpRewritePattern <TensorReshapeOp> {
812
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
813
+
814
+ LogicalResult matchAndRewrite (TensorReshapeOp reshapeOp,
815
+ PatternRewriter &rewriter) const override {
816
+ if (!reshapeOp.src ().getDefiningOp <InitTensorOp>())
817
+ return failure ();
818
+ RankedTensorType collapsedType = reshapeOp.getSrcType ();
819
+ RankedTensorType expandedType = reshapeOp.getResultType ();
820
+ bool isCollapsed = expandedType.getRank () < collapsedType.getRank ();
821
+ if (isCollapsed)
822
+ std::swap (collapsedType, expandedType);
823
+ Value initTensorOp = isCollapsed
824
+ ? getCollapsedInitTensor (rewriter, reshapeOp)
825
+ : getExpandedInitTensor (rewriter, reshapeOp);
826
+ if (!initTensorOp)
827
+ return failure ();
828
+ rewriter.replaceOp (reshapeOp, initTensorOp);
829
+ return success ();
830
+ }
831
+ };
832
+ } // namespace
833
+
721
834
void InitTensorOp::getCanonicalizationPatterns (
722
835
OwningRewritePatternList &results, MLIRContext *context) {
723
- results.insert <ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
836
+ results.insert <FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
837
+ ReplaceStaticShapeDims>(context);
724
838
}
725
839
726
840
//===----------------------------------------------------------------------===//
@@ -1043,23 +1157,23 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
1043
1157
ArrayRef<int64_t > expandedShape = expandedType.getShape ();
1044
1158
unsigned expandedDimStart = 0 ;
1045
1159
for (auto map : llvm::enumerate (op.getReassociationMaps ())) {
1046
- Optional<int64_t > dynamicDims ;
1160
+ Optional<int64_t > dynamicShape ;
1047
1161
int64_t linearizedStaticShape = 1 ;
1048
1162
for (auto dim : llvm::enumerate (expandedShape.slice (
1049
1163
expandedDimStart, map.value ().getNumResults ()))) {
1050
1164
if (ShapedType::isDynamic (dim.value ())) {
1051
- if (isExpandingReshape && dynamicDims ) {
1165
+ if (isExpandingReshape && dynamicShape ) {
1052
1166
return op->emitOpError (" invalid to have a single dimension (" )
1053
1167
<< map.index () << " ) expanded into multiple dynamic dims ("
1054
- << expandedDimStart + dynamicDims .getValue () << " ,"
1168
+ << expandedDimStart + dynamicShape .getValue () << " ,"
1055
1169
<< expandedDimStart + dim.index () << " )" ;
1056
1170
}
1057
- dynamicDims = dim.index ();
1171
+ dynamicShape = dim.index ();
1058
1172
} else {
1059
1173
linearizedStaticShape *= dim.value ();
1060
1174
}
1061
1175
}
1062
- if (dynamicDims ) {
1176
+ if (dynamicShape ) {
1063
1177
if (!ShapedType::isDynamic (collapsedShape[map.index ()])) {
1064
1178
return op->emitOpError (" expected dimension " )
1065
1179
<< map.index ()
0 commit comments