Skip to content

Commit eee8805

Browse files
committed
remove tensor casting
1 parent be6a119 commit eee8805

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17991799

18001800
static MemRefType computeCollapsedType(
18011801
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
1802+
static MemRefType
1803+
inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
1804+
static MemRefType
1805+
inferCollapsedType(MemRefType type,
1806+
SmallVector<ReassociationIndices> reassociation);
18021807
}];
18031808

18041809
let hasVerifier = 1;

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
1818
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1919
#include "mlir/Dialect/Linalg/Utils/Utils.h"
20+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2021
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
2122
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2223
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -410,13 +411,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
410411
stripMinedType.cast<RankedTensorType>(),
411412
packingMetadata.reassociations);
412413
} else if (stripMinedType.isa<MemRefType>()) {
413-
auto memrefTy = stripMinedType.cast<MemRefType>();
414-
auto tensorTy =
415-
RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
416-
auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
417-
tensorTy, packingMetadata.reassociations);
418-
collapsedType = MemRefType::get(collapsedTensorType.getShape(),
419-
collapsedTensorType.getElementType());
414+
collapsedType = memref::CollapseShapeOp::inferCollapsedType(
415+
stripMinedType.cast<MemRefType>(), packingMetadata.reassociations);
420416
}
421417

422418
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/Arith/Utils/Utils.h"
1111
#include "mlir/Dialect/MemRef/IR/MemRef.h"
12+
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1213
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1314
#include "mlir/IR/AffineMap.h"
1415
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
11241125
}
11251126
} // else dim.getIndex is a block argument to reshape->getBlock and
11261127
// dominates reshape
1127-
} // Check condition 2
1128+
} // Check condition 2
11281129
else if (dim->getBlock() != reshape->getBlock() &&
11291130
!dim.getIndex().getParentRegion()->isProperAncestor(
11301131
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
25252526
srcType.getMemorySpace());
25262527
}
25272528

2529+
MemRefType
2530+
CollapseShapeOp::inferCollapsedType(MemRefType type,
2531+
ArrayRef<AffineMap> reassociation) {
2532+
auto shape = type.getShape();
2533+
SmallVector<int64_t, 4> newShape;
2534+
assert(isReassociationValid(reassociation) && "invalid reassociation");
2535+
unsigned currentDim = 0;
2536+
for (AffineMap m : reassociation) {
2537+
unsigned dim = m.getNumResults();
2538+
auto band = shape.slice(currentDim, dim);
2539+
int64_t size = 1;
2540+
if (llvm::is_contained(band, ShapedType::kDynamic))
2541+
size = ShapedType::kDynamic;
2542+
else
2543+
for (unsigned d = 0; d < dim; ++d)
2544+
size *= shape[currentDim + d];
2545+
newShape.push_back(size);
2546+
currentDim += dim;
2547+
}
2548+
return MemRefType::get(newShape, type.getElementType());
2549+
}
2550+
2551+
MemRefType CollapseShapeOp::inferCollapsedType(
2552+
MemRefType type, SmallVector<ReassociationIndices> reassociation) {
2553+
return inferCollapsedType(
2554+
type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
2555+
type.getContext(), reassociation)));
2556+
}
2557+
25282558
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
25292559
ArrayRef<ReassociationIndices> reassociation,
25302560
ArrayRef<NamedAttribute> attrs) {

0 commit comments

Comments
 (0)