From 3ac41911362c402a7bb3be46e7e58838be196ee5 Mon Sep 17 00:00:00 2001
From: Aditya Pradhan
Date: Tue, 22 Jul 2025 14:01:04 +0530
Subject: [PATCH] [TorchToLinalg] Fix the lowering of `AtenIndexTensorHackedTwinOp`

In some cases, the index operand may not be a 64-bit integer. To ensure
type compatibility in `arith.cmpi`, the index operand should be
sign-extended to 64 bits when required.
---
 .../TorchToLinalg/IndirectDataMovement.cpp |  8 ++++++++
 test/Conversion/TorchToLinalg/basic.mlir   | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index 07e4b23a167d..d730089b5ca6 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -507,6 +507,14 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index,
                                     Value input, int64_t dim) {
   Value cstZero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
+  auto indexType = mlir::dyn_cast<mlir::IntegerType>(index.getType());
+  unsigned maxBitWidth = 64;
+  assert((indexType && indexType.isSignless()) &&
+         "The index operand must be a signless integer");
+  assert((indexType.getWidth() <= maxBitWidth) &&
+         "Maximum supported bitwidth of the index operand is 64");
+  if (indexType.getWidth() < maxBitWidth)
+    index = b.create<arith::ExtSIOp>(loc, b.getIntegerType(maxBitWidth), index);
   Value isIndexNegative =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, index, cstZero);
   Value inputShape = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim));
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index c7d6149a8fcd..0beeeeb85ed0 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -459,3 +459,21 @@ func.func @torch.ops.aten.replication_pad3d$basic(%arg0: !torch.vtensor<[4,3,5],
   %0 = torch.aten.replication_pad3d %arg0, %padding : !torch.vtensor<[4,3,5],f32>, !torch.list<int> -> !torch.vtensor<[7,7,6],f32>
   return %0 : !torch.vtensor<[7,7,6],f32>
 }
+
+// -----
+
+// This test verifies that the index argument is properly sign-extended,
+// when torch.aten.index.Tensor_hacked_twin is lowered into a linalg.generic
+// operation.
+//
+// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : i64
+// CHECK-NEXT: %[[IN_SIGN_EXT:.*]] = arith.extsi %[[IN]] : i32 to i64
+// CHECK-NEXT: arith.cmpi slt, %[[IN_SIGN_EXT]], %[[C0]] : i64
+func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[1,1,8],si32>, %arg1: !torch.vtensor<[16], f32>) -> !torch.vtensor<[1,1,8],f32> {
+  %0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1,1,8],si32>) -> !torch.list<vtensor>
+  %1 = torch.aten.index.Tensor_hacked_twin %arg1, %0 : !torch.vtensor<[16],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,1,8],f32>
+  return %1 : !torch.vtensor<[1,1,8],f32>
+}