
Commit 0189977

[OnnxToTorch] Support non-rank-0 Loop index tensors (#4098)
The `onnx.Loop` index value isn't required to be rank-0 (e.g. it may have a shape of `[1]`), so we can't use `NumToTensor.Scalar`, which only creates rank-0 tensors. This switches the lowering to `aten.full`, which eventually canonicalizes to `NumToTensor.Scalar` in the rank-0 case.
1 parent b6c4e87 · commit 0189977
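For context, a minimal sketch (illustrative only, not part of the commit; the SSA and block-argument names are assumptions) of why the old lowering breaks when the loop index has shape `[1]`: `torch.prim.NumToTensor.Scalar` is pinned to a rank-0 result, so the tensor it produces can never match a `[1]`-shaped index.

^bb0(%iter: !torch.int, %carried: !torch.vtensor<[1],f32>):
  // Old lowering: the result type is always rank-0 ([]), even when the
  // original ONNX loop index was declared with shape [1].
  %t = torch.prim.NumToTensor.Scalar %iter : !torch.int -> !torch.vtensor<[],si64>
  // Any use expecting !torch.vtensor<[1],si64> now fails to type-check.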

2 files changed (+16 −8 lines)

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 12 additions & 7 deletions
@@ -376,20 +376,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // primLoopOp loopBody expects torch.int as first arg
         // insert torch.int arg in loop body, convert to tensor,
         // replace all uses of old arg, delete old arg.
-        auto loopVarArg = loop.getRegion().front().getArgument(0);
+        auto loopVar = loop.getRegion().front().getArgument(0);
         // insert new Arg
         loop.getRegion().front().insertArgument(
             0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
         auto newLoopVarArg = loop.getRegion().front().getArgument(0);

         // convert int arg to tensor of original Type
         rewriter.setInsertionPointToStart(&loop.getRegion().front());
-        Value loopVarVal = BlockArgument::Value(loopVarArg);
-        auto newTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
-            loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(),
-            newLoopVarArg);
-
-        loopVarArg.replaceAllUsesWith(newTensor);
+        auto loopVarType = dyn_cast<Torch::BaseTensorType>(loopVar.getType());
+        if (!loopVarType || !loopVarType.areAllSizesKnown())
+          return rewriter.notifyMatchFailure(
+              loopVar.getLoc(),
+              "loop iteration value must be a tensor with known sizes");
+        Value sizes = Torch::toIntListConstruct(rewriter, loopVar.getLoc(),
+                                                loopVarType.getSizes());
+        auto newTensor = torch::Torch::createInitTensor(
+            rewriter, loopVar.getLoc(), loopVarType, newLoopVarArg, sizes);
+
+        loopVar.replaceAllUsesWith(newTensor);
         loop.getRegion().eraseArgument(1);

         // primLoopOp loopBody has no condition arg
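Judging from the updated FileCheck expectations below, `Torch::toIntListConstruct` materializes the size list as a `torch.prim.ListConstruct` and `Torch::createInitTensor` emits a `torch.aten.full`. A rough sketch of what this path would produce for the motivating shape-`[1]` index (SSA names like `%iter` are illustrative assumptions, not taken from the commit):

// Sketch of the new emission for a shape-[1] loop index.
%int1 = torch.constant.int 1
%sizes = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%none = torch.constant.none
%int4 = torch.constant.int 4  // torch dtype 4 encodes si64
%t = torch.aten.full %sizes, %iter, %int4, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>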

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 4 additions & 1 deletion
@@ -1977,7 +1977,10 @@ func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtens
 // CHECK: %[[TRUE:.*]] = torch.constant.bool true
 // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[MAX_TRIP_COUNT_INT]], %[[TRUE]], init(%[[LCD_1]]) {
 // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[LCD_1_BODY:.*]]: !torch.vtensor<[1],f32>):
-// CHECK: %[[ITER_NUM_T:.*]] = torch.prim.NumToTensor.Scalar %[[ITER_NUM]] : !torch.int -> !torch.vtensor<[],si64>
+// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[NONE_ITER_NUM_T:.*]] = torch.constant.none
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
+// CHECK: %[[ITER_NUM_T:.*]] = torch.aten.full %[[SHAPE]], %[[ITER_NUM]], %[[DTYPE]], %[[NONE_ITER_NUM_T]], %[[NONE_ITER_NUM_T]], %[[NONE_ITER_NUM_T]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],si64>
 // CHECK: %[[NONE_1:.*]] = torch.constant.none
 // CHECK: %[[CLONE_INP_COND:.*]] = torch.aten.clone %[[CONDITION_INP]], %[[NONE_1]] : !torch.vtensor<[],i1>, !torch.none -> !torch.vtensor<[],i1>
 // CHECK: %[[CONST_ARR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>) : !torch.vtensor<[5],f32>
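Note that this test's loop index is still rank-0, so the `aten.full` above takes an empty size list; per the commit message, that form eventually canonicalizes back to `NumToTensor.Scalar`. A sketch of that rank-0 round trip (illustrative names, not part of the test):

// Before canonicalization: aten.full with an empty size list.
%sizes = torch.prim.ListConstruct : () -> !torch.list<int>
%none = torch.constant.none
%int4 = torch.constant.int 4
%t = torch.aten.full %sizes, %iter, %int4, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],si64>
// After canonicalization: the rank-0 full folds to NumToTensor.Scalar.
%t2 = torch.prim.NumToTensor.Scalar %iter : !torch.int -> !torch.vtensor<[],si64>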
