
Commit 5c37c6e

[TorchToLinalg] Fix the lowering of AtenIndexTensorHackedTwinOp
In some cases, the index operand may not be a 64-bit integer. To ensure type compatibility in `arith.cmpi`, the index operand should be sign-extended to 64 bits when required.
1 parent 46c3888
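
For context, arith.cmpi requires both operands to have the same type, so an i32 index cannot be compared directly against the i64 zero constant that makeIndexValuePositive creates. A minimal sketch of the IR after the fix (the %idx name is hypothetical, not taken from the patch):

  %c0 = arith.constant 0 : i64
  // Sign-extend the narrower index so both cmpi operands are i64.
  %ext = arith.extsi %idx : i32 to i64
  %neg = arith.cmpi slt, %ext, %c0 : i64

Sign extension (arith.extsi) rather than zero extension is what preserves negative index values, which the helper then normalizes to a non-negative value.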

2 files changed, 24 insertions(+), 0 deletions(-)
lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp (6 additions, 0 deletions)

@@ -507,6 +507,12 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index,
                                     Value input, int64_t dim) {
   Value cstZero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
+  auto indexType = mlir::dyn_cast<IntegerType>(index.getType());
+  unsigned maxBitWidth = 64;
+  assert((indexType && indexType.getWidth() <= maxBitWidth) &&
+         "Maximum supported bitwidth of the index operand is 64");
+  if (indexType.getWidth() < maxBitWidth)
+    index = b.create<arith::ExtSIOp>(loc, b.getIntegerType(maxBitWidth), index);
   Value isIndexNegative =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, index, cstZero);
   Value inputShape = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim));

test/Conversion/TorchToLinalg/basic.mlir (18 additions, 0 deletions)

@@ -459,3 +459,21 @@ func.func @torch.ops.aten.replication_pad3d$basic(%arg0: !torch.vtensor<[4,3,5],
   %0 = torch.aten.replication_pad3d %arg0, %padding : !torch.vtensor<[4,3,5],f32>, !torch.list<int> -> !torch.vtensor<[7,7,6],f32>
   return %0 : !torch.vtensor<[7,7,6],f32>
 }
+
+// -----
+
+// This test verifies that the index argument is properly sign-extended
+// when torch.aten.index.Tensor_hacked_twin is lowered into a linalg.generic
+// operation.
+//
+// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : i64
+// CHECK-NEXT: %[[IN_SIGN_EXT:.*]] = arith.extsi %[[IN]] : i32 to i64
+// CHECK-NEXT: arith.cmpi slt, %[[IN_SIGN_EXT]], %[[C0]] : i64
+func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[1,1,8],si32>, %arg1: !torch.vtensor<[16], f32>) -> !torch.vtensor<[1,1,8],f32> {
+  %0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1,1,8],si32>) -> !torch.list<vtensor>
+  %1 = torch.aten.index.Tensor_hacked_twin %arg1, %0 : !torch.vtensor<[16],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,1,8],f32>
+  return %1 : !torch.vtensor<[1,1,8],f32>
+}
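
Assuming the standard lit setup for this file (the RUN line at the top of basic.mlir is not shown in this diff), the new case can be exercised with something like:

  torch-mlir-opt -convert-torch-to-linalg -split-input-file \
      test/Conversion/TorchToLinalg/basic.mlir | FileCheck test/Conversion/TorchToLinalg/basic.mlir

The // ----- separator is what -split-input-file splits on, so the new test runs independently of the replication_pad3d case above it.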
