@@ -744,51 +744,59 @@ static mlir::Value makeSubview(mlir::OpBuilder &builder, mlir::Location loc,
744
744
assert (dstRank > 0 );
745
745
assert (dstRank <= srcRank);
746
746
747
- bool useReduceRank = true ;
748
-
749
747
mlir::Value view;
750
748
if (isTensor) {
751
- auto resType =
752
- [&]() {
753
- auto tensorType = srcType.cast <mlir::RankedTensorType>();
754
- if (srcRank == dstRank || useReduceRank)
755
- return mlir::tensor::ExtractSliceOp::inferResultType (
756
- tensorType, offsets, sizes, strides);
757
-
758
- return mlir::tensor::ExtractSliceOp::inferRankReducedResultType (
759
- dstRank, tensorType, offsets, sizes, strides);
760
- }()
761
- .cast <mlir::RankedTensorType>();
749
+ auto tensorType = srcType.cast <mlir::RankedTensorType>();
750
+ auto resType = mlir::tensor::ExtractSliceOp::inferResultType (
751
+ tensorType, offsets, sizes, strides)
752
+ .cast <mlir::RankedTensorType>();
762
753
763
754
view = builder.create <mlir::tensor::ExtractSliceOp>(
764
755
loc, resType, src, offsets, sizes, strides);
756
+
757
+ if (srcRank != dstRank) {
758
+ llvm::SmallVector<mlir::OpFoldResult> newOfsets (srcRank,
759
+ builder.getIndexAttr (0 ));
760
+ llvm::SmallVector<mlir::OpFoldResult> newStrides (srcRank,
761
+ builder.getIndexAttr (1 ));
762
+ auto viewType = view.getType ().cast <mlir::RankedTensorType>();
763
+ auto reducedType =
764
+ mlir::tensor::ExtractSliceOp::inferRankReducedResultType (
765
+ dstRank, viewType, newOfsets, sizes, newStrides)
766
+ .cast <mlir::RankedTensorType>();
767
+ view = builder.create <mlir::tensor::ExtractSliceOp>(
768
+ loc, reducedType, view, newOfsets, sizes, newStrides);
769
+ }
765
770
} else {
766
- auto resType =
767
- [&]() {
768
- auto memrefType = srcType.cast <mlir::MemRefType>();
769
- if (srcRank == dstRank || useReduceRank)
770
- return mlir::memref::SubViewOp::inferResultType (memrefType, offsets,
771
- sizes, strides);
772
-
773
- return mlir::memref::SubViewOp::inferRankReducedResultType (
774
- dstRank, memrefType, offsets, sizes, strides);
775
- }()
776
- .cast <mlir::MemRefType>();
771
+ auto memrefType = srcType.cast <mlir::MemRefType>();
772
+ auto resType = mlir::memref::SubViewOp::inferResultType (memrefType, offsets,
773
+ sizes, strides)
774
+ .cast <mlir::MemRefType>();
777
775
778
776
view = builder.create <mlir::memref::SubViewOp>(loc, resType, src, offsets,
779
777
sizes, strides);
780
778
779
+ if (srcRank != dstRank) {
780
+ llvm::SmallVector<mlir::OpFoldResult> newOfsets (srcRank,
781
+ builder.getIndexAttr (0 ));
782
+ llvm::SmallVector<mlir::OpFoldResult> newStrides (srcRank,
783
+ builder.getIndexAttr (1 ));
784
+ auto viewType = view.getType ().cast <mlir::MemRefType>();
785
+ auto reducedType = mlir::memref::SubViewOp::inferRankReducedResultType (
786
+ dstRank, viewType, newOfsets, sizes, newStrides)
787
+ .cast <mlir::MemRefType>();
788
+ view = builder.create <mlir::memref::SubViewOp>(
789
+ loc, reducedType, view, newOfsets, sizes, newStrides);
790
+ resType = reducedType;
791
+ }
792
+
781
793
auto flatMemrefType =
782
794
mlir::MemRefType::get (resType.getShape (), resType.getElementType ());
795
+
783
796
if (resType != flatMemrefType)
784
797
view = builder.create <plier::ChangeLayoutOp>(loc, flatMemrefType, view);
785
798
}
786
799
787
- if (srcRank != dstRank && useReduceRank) {
788
- llvm::SmallVector<int32_t > mapping (dimIndices.begin (), dimIndices.end ());
789
- view = builder.createOrFold <plier::ReduceRankOp>(loc, view, mapping);
790
- }
791
-
792
800
return view;
793
801
}
794
802
0 commit comments