Skip to content

Commit 2a9cc20

Browse files
authored
Do not use ReduceRankOp in makeSubview (#151)
1 parent cbacf1c commit 2a9cc20

File tree

3 files changed

+43
-37
lines changed

3 files changed

+43
-37
lines changed

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,10 @@ mlir::Operation *PlierUtilDialect::materializeConstant(mlir::OpBuilder &builder,
7979
if (mlir::arith::ConstantOp::isBuildableWith(value, type))
8080
return builder.create<mlir::arith::ConstantOp>(loc, type, value);
8181

82+
if (type.isa<mlir::IndexType>())
83+
if (auto val = mlir::getConstantIntValue(value))
84+
return builder.create<mlir::arith::ConstantIndexOp>(loc, *val);
85+
8286
return nullptr;
8387
}
8488

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -744,13 +744,7 @@ struct FlattenSubview : public mlir::OpRewritePattern<mlir::memref::SubViewOp> {
744744
flatSubview = rewriter.createOrFold<mlir::memref::CastOp>(
745745
loc, dstFlatType, flatSubview);
746746

747-
// TODO: bug in ReinterpretCastOp::verify
748-
auto offset =
749-
(mlir::ShapedType::isDynamicStrideOrOffset(resultOffset)
750-
? mlir::OpFoldResult(
751-
rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0)
752-
.getResult())
753-
: mlir::OpFoldResult(rewriter.getIndexAttr(0)));
747+
auto offset = rewriter.getIndexAttr(0);
754748

755749
for (auto i : llvm::seq<size_t>(0, strides.size())) {
756750
if (mlir::ShapedType::isDynamicStrideOrOffset(resultStrides[i])) {
@@ -789,7 +783,7 @@ struct FlattenSubview : public mlir::OpRewritePattern<mlir::memref::SubViewOp> {
789783

790784
auto droppedDims = op.getDroppedDims();
791785
for (auto i : llvm::seq(0u, srcRank)) {
792-
if (droppedDims[i]) {
786+
if (!droppedDims[i]) {
793787
filteredSizes.emplace_back(sizes[i]);
794788
filteredStrides.emplace_back(strides[i]);
795789
}

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 37 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -744,51 +744,59 @@ static mlir::Value makeSubview(mlir::OpBuilder &builder, mlir::Location loc,
744744
assert(dstRank > 0);
745745
assert(dstRank <= srcRank);
746746

747-
bool useReduceRank = true;
748-
749747
mlir::Value view;
750748
if (isTensor) {
751-
auto resType =
752-
[&]() {
753-
auto tensorType = srcType.cast<mlir::RankedTensorType>();
754-
if (srcRank == dstRank || useReduceRank)
755-
return mlir::tensor::ExtractSliceOp::inferResultType(
756-
tensorType, offsets, sizes, strides);
757-
758-
return mlir::tensor::ExtractSliceOp::inferRankReducedResultType(
759-
dstRank, tensorType, offsets, sizes, strides);
760-
}()
761-
.cast<mlir::RankedTensorType>();
749+
auto tensorType = srcType.cast<mlir::RankedTensorType>();
750+
auto resType = mlir::tensor::ExtractSliceOp::inferResultType(
751+
tensorType, offsets, sizes, strides)
752+
.cast<mlir::RankedTensorType>();
762753

763754
view = builder.create<mlir::tensor::ExtractSliceOp>(
764755
loc, resType, src, offsets, sizes, strides);
756+
757+
if (srcRank != dstRank) {
758+
llvm::SmallVector<mlir::OpFoldResult> newOfsets(srcRank,
759+
builder.getIndexAttr(0));
760+
llvm::SmallVector<mlir::OpFoldResult> newStrides(srcRank,
761+
builder.getIndexAttr(1));
762+
auto viewType = view.getType().cast<mlir::RankedTensorType>();
763+
auto reducedType =
764+
mlir::tensor::ExtractSliceOp::inferRankReducedResultType(
765+
dstRank, viewType, newOfsets, sizes, newStrides)
766+
.cast<mlir::RankedTensorType>();
767+
view = builder.create<mlir::tensor::ExtractSliceOp>(
768+
loc, reducedType, view, newOfsets, sizes, newStrides);
769+
}
765770
} else {
766-
auto resType =
767-
[&]() {
768-
auto memrefType = srcType.cast<mlir::MemRefType>();
769-
if (srcRank == dstRank || useReduceRank)
770-
return mlir::memref::SubViewOp::inferResultType(memrefType, offsets,
771-
sizes, strides);
772-
773-
return mlir::memref::SubViewOp::inferRankReducedResultType(
774-
dstRank, memrefType, offsets, sizes, strides);
775-
}()
776-
.cast<mlir::MemRefType>();
771+
auto memrefType = srcType.cast<mlir::MemRefType>();
772+
auto resType = mlir::memref::SubViewOp::inferResultType(memrefType, offsets,
773+
sizes, strides)
774+
.cast<mlir::MemRefType>();
777775

778776
view = builder.create<mlir::memref::SubViewOp>(loc, resType, src, offsets,
779777
sizes, strides);
780778

779+
if (srcRank != dstRank) {
780+
llvm::SmallVector<mlir::OpFoldResult> newOfsets(srcRank,
781+
builder.getIndexAttr(0));
782+
llvm::SmallVector<mlir::OpFoldResult> newStrides(srcRank,
783+
builder.getIndexAttr(1));
784+
auto viewType = view.getType().cast<mlir::MemRefType>();
785+
auto reducedType = mlir::memref::SubViewOp::inferRankReducedResultType(
786+
dstRank, viewType, newOfsets, sizes, newStrides)
787+
.cast<mlir::MemRefType>();
788+
view = builder.create<mlir::memref::SubViewOp>(
789+
loc, reducedType, view, newOfsets, sizes, newStrides);
790+
resType = reducedType;
791+
}
792+
781793
auto flatMemrefType =
782794
mlir::MemRefType::get(resType.getShape(), resType.getElementType());
795+
783796
if (resType != flatMemrefType)
784797
view = builder.create<plier::ChangeLayoutOp>(loc, flatMemrefType, view);
785798
}
786799

787-
if (srcRank != dstRank && useReduceRank) {
788-
llvm::SmallVector<int32_t> mapping(dimIndices.begin(), dimIndices.end());
789-
view = builder.createOrFold<plier::ReduceRankOp>(loc, view, mapping);
790-
}
791-
792800
return view;
793801
}
794802

0 commit comments

Comments
 (0)