Skip to content

Commit f7a92d3

Browse files
Mi-Jiazhiqingyunqu and LiuYuanqiang authored
[Torch Dialect] Decompose AtenTriuOp (#2561)
decompose like: ``` import torch def my_triu(x, diag): rows = torch.ops.aten.size(x, -2) cols = torch.ops.aten.size(x, -1) row_indices = torch.ops.aten.arange(rows).unsqueeze(1) col_indices = torch.ops.aten.arange(cols).unsqueeze(0) cond = torch.ops.aten.ge( col_indices, torch.ops.aten.add(row_indices, diag)) return torch.ops.aten.where(cond, x, 0) x = torch.rand(5, 7) assert torch.allclose(my_triu(x, 0), torch.triu(x, 0)) assert torch.allclose(my_triu(x, 1), torch.triu(x, 1)) assert torch.allclose(my_triu(x, 2), torch.triu(x, 2)) assert torch.allclose(my_triu(x, -1), torch.triu(x, -1)) ``` --------- Co-authored-by: LiuYuanqiang <[email protected]>
1 parent 49fdc1a commit f7a92d3

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,62 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
246246
};
247247
} // end namespace
248248

249+
namespace {
// Decompose `aten.triu(input, diagonal)` into primitive ops.
//
// The decomposition builds a boolean mask over the two trailing (matrix)
// dimensions,
//   cond[i][j] = (j >= i + diagonal)
// and then selects
//   result = aten.where(cond, input, 0)
// so elements on/above the `diagonal`-th diagonal are kept and all others
// are replaced with zero. The mask has shape (rows, cols) and broadcasts
// over any leading batch dimensions of `input`.
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTriuOp op,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = op.getContext();
    Location loc = op.getLoc();
    Value input = op.getSelf();
    // Require known sizes and dtype so downstream type refinement can
    // reconcile the unrefined intermediate types created below.
    auto inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasSizes() || !inputType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "should have shape and dtype");
    }
    // `triu` is only defined for matrices (and batches of matrices).
    if (inputType.getSizes().size() < 2) {
      return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
    }

    // Intermediates are given the least-static tensor type; shape
    // inference refines them after decomposition.
    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
    Value cstZero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value cstOne =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    Value none = rewriter.create<ConstantNoneOp>(loc);

    // Sizes of the two trailing dimensions: rows = size(input, -2),
    // cols = size(input, -1).
    Value rowDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-2));
    Value colDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-1));
    Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
    Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);

    // row_indices = arange(rows), col_indices = arange(cols).
    Value rowArange = rewriter.create<AtenArangeOp>(
        loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);
    Value colArange = rewriter.create<AtenArangeOp>(
        loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);

    // Reshape to (rows, 1) and (1, cols) so the comparison below
    // broadcasts to a full (rows, cols) mask.
    Value unsqueezeRowArange =
        rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
    Value unsqueezeColArange =
        rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);

    // row_indices + diagonal (scalar add, alpha = 1).
    Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
        loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);

    // cond[i][j] = (col_indices[j] >= row_indices[i] + diagonal).
    Value condTensor = rewriter.create<AtenGeTensorOp>(
        loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);

    // result = where(cond, input, 0), with the op's original result type.
    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
        op, op.getResult().getType(), condTensor, input, cstZero);
    return success();
  }
};
} // namespace
304+
249305
namespace {
250306
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
251307
public:
@@ -5817,6 +5873,7 @@ class DecomposeComplexOpsPass
58175873
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
58185874
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
58195875
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
5876+
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
58205877

58215878
GreedyRewriteConfig config;
58225879
config.useTopDownTraversal = true;

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
500500
target.addIllegalOp<AtenTypeAsOp>();
501501
target.addIllegalOp<AtenTileOp>();
502502
target.addIllegalOp<AtenReshapeAsOp>();
503+
target.addIllegalOp<AtenTriuOp>();
503504
for (auto &opName : backendLegalOpsSet) {
504505
target.addLegalOp(
505506
OperationName(kTorchOpPrefix + opName.first().str(), context));

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,6 +3251,52 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
32513251
# ==============================================================================
32523252

32533253

3254+
class TriuModule(torch.nn.Module):
    """e2e test module for `aten.triu` with a positive diagonal (1) on a
    static 4x5 float32 input; exercised by `TriuModule_basic`."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4,5], torch.float32, True),
    ])
    def forward(self, x):
        # Keep elements on/above the first superdiagonal; zero the rest.
        return torch.ops.aten.triu(x, 1)
3265+
3266+
3267+
@register_test_case(module_factory=lambda: TriuModule())
def TriuModule_basic(module, tu: TestUtils):
    """Run TriuModule on a fixed 4x5 input that includes NaN and +/-inf."""
    rows = [
        [0.5876, -0.0794, -1.8373, 0.6654, 0.2],
        [-0.2447, 0.9556, -1.2919, 1.3378, 0.3],
        [0.4333, 0.3146, 0.6576, -1.0432, 0.4],
        [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5],
    ]
    module.forward(torch.tensor(rows))
3274+
3275+
3276+
# ==============================================================================
3277+
3278+
3279+
class TriuBroadcastModule(torch.nn.Module):
    """e2e test module for `aten.triu` with diagonal=2 on a rank-4 static
    float32 input — presumably named "Broadcast" because the triu mask must
    apply across the leading batch dimensions; exercised by
    `TriuBroadcastModule_basic`."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3,4,5,6], torch.float32, True),
    ])
    def forward(self, x):
        # triu over the last two dims with the 2nd superdiagonal cutoff.
        return torch.ops.aten.triu(x, 2)
3290+
3291+
3292+
@register_test_case(module_factory=lambda: TriuBroadcastModule())
def TriuBroadcastModule_basic(module, tu: TestUtils):
    """Run TriuBroadcastModule on a random rank-4 (3x4x5x6) input."""
    sample = tu.rand(3, 4, 5, 6)
    module.forward(sample)
3295+
3296+
3297+
# ==============================================================================
3298+
3299+
32543300
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
32553301

32563302
def __init__(self):

0 commit comments

Comments
 (0)