
Commit 0c184df

init

1 parent 226230c

6 files changed: +57 -25 lines changed


mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 5 additions & 5 deletions

@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
+      return ::llvm::cast<ShapedType>(getSource().getType()); };
+    ShapedType getDestType() {
+      return ::llvm::cast<ShapedType>(getDest().getType()); };

     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }

@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos) and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+    static RankedTensorType inferPackedType(ShapedType sourceType,
        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
        ArrayRef<int64_t> outerDimsPerm = {});
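For context: ShapedType is the MLIR type interface implemented by both RankedTensorType and MemRefType, which is what lets these accessors serve tensor and memref operands through one code path. A minimal sketch of the idea (the helper name is hypothetical, not part of the commit):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Hypothetical helper: rank/shape queries go through the ShapedType interface,
// so the same code handles tensor<...> and memref<...> values alike.
static int64_t getRelayoutRank(mlir::Value value) {
  auto shaped = ::llvm::cast<mlir::ShapedType>(value.getType());
  return shaped.getRank();
}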

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 7 deletions

@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

-    RankedTensorType srcPadType = srcPadOp.getSourceType();
+    ShapedType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
@@ -4433,7 +4433,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(

 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
        rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
  }
  Value dest = packOp.getDest();
-  RankedTensorType originalResultType = packOp.getDestType();
+  ShapedType originalResultType = packOp.getDestType();
  bool needUpdateDestType = (destShape != originalResultType.getShape());
  if (needUpdateDestType) {
    auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  rewriter.modifyOpInPlace(packOp, [&] {
    packOp.getSourceMutable().assign(source);
    packOp.getDestMutable().assign(dest);
-    packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    packOp.getResult().setType(cast<ShapedType>(dest.getType()));
  });
  // Insert a cast if needed
  if (needUpdateDestType) {
@@ -4970,7 +4970,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {

 template <typename PackOrUnpackOp>
 static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           RankedTensorType packedTensorType) {
+                           ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
@@ -5274,7 +5274,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
 }

 bool UnPackOp::isLikeUnPad() {
-  RankedTensorType packedTensorType = getSourceType();
+  ShapedType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 6 additions & 6 deletions

@@ -111,15 +111,15 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

-    RankedTensorType sourceType = packOp.getSourceType();
+    ShapedType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

-    RankedTensorType destType = packOp.getDestType();
+    ShapedType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
          "expects outer_dims_perm is empty or an identity permutation");
    }

-    RankedTensorType sourceType = unpackOp.getSourceType();
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

@@ -173,15 +173,15 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

-    RankedTensorType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 25 additions & 6 deletions

@@ -359,7 +359,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

-  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  ShapedType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,18 +396,37 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
+  ShapedType stripMinedType;
+  if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+    stripMinedType =
+        RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+  } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+    stripMinedType =
+        MemRefType::get(stripMinedShape, memrefType.getElementType());
+  }
+  ShapedType collapsedType;
+  if (stripMinedType.isa<TensorType>()) {
+    collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+        stripMinedType.cast<RankedTensorType>(),
+        packingMetadata.reassociations);
+  } else if (stripMinedType.isa<MemRefType>()) {
+    auto memrefTy = stripMinedType.cast<MemRefType>();
+    auto tensorTy =
+        RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
+    auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
+        tensorTy, packingMetadata.reassociations);
+    // Rebuild the collapsed tensor type as a memref (keeping the same memory space).
+    collapsedType = MemRefType::get(collapsedTensorType.getShape(),
+                                    collapsedTensorType.getElementType());
+  }

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, dims, stripMinedTensorType.getElementType());
+      loc, dims, stripMinedType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
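The new tensor-versus-memref branching above is a plain if/else chain on the packed type. Only as an illustration of that dispatch (not code from the commit), the same selection could be expressed with llvm::TypeSwitch; the helper name and parameters below are hypothetical, and, as in the committed code, memref layout and memory space are not propagated.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper equivalent to the if/else chain above: build the
// strip-mined type for either a ranked tensor or a memref source.
static mlir::ShapedType
buildStripMinedType(mlir::ShapedType packedType,
                    llvm::ArrayRef<int64_t> stripMinedShape) {
  return llvm::TypeSwitch<mlir::Type, mlir::ShapedType>(packedType)
      .Case<mlir::RankedTensorType>(
          [&](mlir::RankedTensorType t) -> mlir::ShapedType {
            return mlir::RankedTensorType::get(stripMinedShape,
                                               t.getElementType());
          })
      .Case<mlir::MemRefType>([&](mlir::MemRefType t) -> mlir::ShapedType {
        // Note: drops the source layout and memory space, as in the commit.
        return mlir::MemRefType::get(stripMinedShape, t.getElementType());
      })
      .Default([](mlir::Type) { return mlir::ShapedType(); });
}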

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 1 deletion

@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  ShapedType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+{
+  "version": "0.2.0",
+  "configurations": [
+    {
+      "name": "ma",
+      "type": "lldb",
+      "request": "launch",
+      "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
+      "args": [],
+      "cwd": "${workspaceFolder}"
+    }
+  ]
+}
