Skip to content

Commit 55ff110

Browse files
authored
[MLIR][TORCH] Only unroll prim loop-like ops within a torch.shape.calculate region (#3812)
Reports a match failure for the pattern `FullyUnrollPrimLoop` when the loop op is not in a region defined by a `torch.shape.calculate` op. This is needed to avoid unrolling prim loops generated by ONNX IR, since we are applying shape refinement in the `torch-onnx-to-torch-backend-pipeline` introduced in fa4794d . See also the discussion in <iree-org/iree#18867 (comment)>
1 parent aca33f1 commit 55ff110

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
3232
} // namespace
3333

3434
namespace {
35-
// TODO: Only unroll inside the shape calculation region.
36-
// Maybe do this by only applying patterns and folding greedily on the ops
37-
// inside the region + the shape.calculate op itself?
3835
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
3936
public:
4037
using OpRewritePattern::OpRewritePattern;
4138
LogicalResult matchAndRewrite(PrimLoopOp op,
4239
PatternRewriter &rewriter) const override {
4340
Location loc = op->getLoc();
4441
MLIRContext *context = op->getContext();
42+
// Only unroll loops if they are contained in a shape calculate region.
43+
Region *region = op->getParentRegion();
44+
Operation *parentOp = region->getParentOp();
45+
if (!parentOp || !isa<Torch::ShapeCalculateOp>(parentOp))
46+
return rewriter.notifyMatchFailure(
47+
op, "Loop is not contained in a shape calculation region.");
4548
if (!op.isForLike())
4649
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
4750
int64_t maxTripCount;

test/Dialect/Torch/simplify-shape-calculations.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,23 @@ func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch
152152
return %0 : !torch.vtensor
153153
}
154154

155+
// CHECK-LABEL: func.func @fully_unroll_prim_loop$outside_region(
156+
// CHECK: %[[LOOP:.*]] = torch.prim.Loop
157+
func.func @fully_unroll_prim_loop$outside_region(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
158+
%true = torch.constant.bool true
159+
%0 = torch.prim.Loop %arg2, %true, init(%arg0) {
160+
^bb0(%arg3: !torch.int, %arg4: !torch.vtensor):
161+
%1 = torch.shape.calculate {
162+
torch.shape.calculate.yield %arg4 : !torch.vtensor
163+
} shapes {
164+
torch.prim.Print(%arg3) : !torch.int
165+
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
166+
} : !torch.vtensor
167+
torch.prim.Loop.condition %true, iter(%1 : !torch.vtensor)
168+
} : (!torch.int, !torch.bool, !torch.vtensor) -> !torch.vtensor
169+
return %0 : !torch.vtensor
170+
}
171+
155172
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic(
156173
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
157174
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,

0 commit comments

Comments
 (0)