Commit 72508c7
[TOSA] Retag resource literals to signless constants (#4367)
- Extend ValueTensorLiteral lowering so DenseResourceElementsAttr integers are rebuilt with signless element types before emitting tosa.const, matching the converted tensor type.
- Add lit coverage for resource-backed i32/i64 vtensor literals.
- Add FX importer e2e tests that return constant int32/int64 tensors.
1 parent 244f4b6 commit 72508c7
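
For illustration, this is the shape of the lowering the change produces, reconstructed from the lit tests added below (the function and resource names are taken from those tests): a resource-backed signed integer literal on the Torch side becomes a signless tosa.const whose values attribute matches the converted tensor type.

    // Input: a resource-backed literal with a signed element type.
    %0 = torch.vtensor.literal(dense_resource<torch_resource_i32> : tensor<4xsi32>) : !torch.vtensor<[4],si32>

    // After TorchToTosa: the resource attribute is retagged to signless i32,
    // matching the result type of the emitted constant.
    %cst = "tosa.const"() <{values = dense_resource<torch_resource_i32> : tensor<4xi32>}> : () -> tensor<4xi32>
    %ret = torch_c.from_builtin_tensor %cst : tensor<4xi32> -> !torch.vtensor<[4],si32>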

File tree: 4 files changed, +101 −2 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 16 additions & 1 deletion
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -3161,7 +3162,21 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
       return success();
     }
   }
-  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.getValue());
+  ElementsAttr attr = cast<ElementsAttr>(adaptor.getValue());
+  if (auto res = dyn_cast<DenseResourceElementsAttr>(attr)) {
+    // Resource blobs preserve the producer's signedness, so retag them here to
+    // keep TOSA constants signless and avoid downstream type mismatches.
+    auto shapedAttrTy = cast<ShapedType>(res.getType());
+    if (auto intTy = dyn_cast<IntegerType>(shapedAttrTy.getElementType())) {
+      if (!intTy.isSignless()) {
+        auto signlessTy =
+            IntegerType::get(rewriter.getContext(), intTy.getWidth());
+        auto newTy = RankedTensorType::get(shapedAttrTy.getShape(), signlessTy);
+        attr = DenseResourceElementsAttr::get(newTy, res.getRawHandle());
+      }
+    }
+  }
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, attr);
   return success();
 }
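
For contrast, a hypothetical sketch of what the old path would have emitted for the same resource literal: adaptor.getValue() was passed through unchanged, so the values attribute kept the producer's signed element type while the result type was signless, the mismatch the retag above avoids.

    // Hypothetical pre-fix output: a tensor<4xsi32> attribute on a
    // tensor<4xi32> result; TOSA expects these types to agree.
    %bad = "tosa.const"() <{values = dense_resource<torch_resource_i32> : tensor<4xsi32>}> : () -> tensor<4xi32>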

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 1 deletion
@@ -680,6 +680,8 @@
     "ChannelShuffleTrailingOnes_basic",
     "ChannelShuffleDynamicDims_basic",
     "ConstantBoolParameterModule_basic",
+    "ConstantInt32ParameterModule_basic",
+    "ConstantInt64ParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dFP16NoBiasModule_basic",
@@ -2882,6 +2884,8 @@
     "ColumnStack1dModule_basic",
     "ColumnStack0dModule_basic",
     "ConstantBoolParameterModule_basic",
+    "ConstantInt32ParameterModule_basic",
+    "ConstantInt64ParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
@@ -3671,7 +3675,6 @@
     "BoolIntTrueModule_basic",
     "BroadcastDynamicDimModule_basic",
     "CeilFloatModule_basic",
-    "ConstantBoolParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 43 additions & 0 deletions
@@ -2976,6 +2976,49 @@ def TensorIntModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ConstantInt32ParameterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.tensor = torch.tensor([0, 10, 128, 17000], dtype=torch.int32)
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return self.tensor
+
+
+@register_test_case(module_factory=lambda: ConstantInt32ParameterModule())
+def ConstantInt32ParameterModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ConstantInt64ParameterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.tensor = torch.tensor([1, -2, 3, -4], dtype=torch.int64)
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return self.tensor
+
+
+@register_test_case(module_factory=lambda: ConstantInt64ParameterModule())
+def ConstantInt64ParameterModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
 class tensorFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 38 additions & 0 deletions
@@ -1051,6 +1051,44 @@ func.func @torch.vtensor.literal_si32$basic() -> !torch.vtensor<[1,512],si32> {

 // -----

+// CHECK-LABEL: @torch.vtensor.literal_resource_si32$basic(
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<torch_resource_i32> : tensor<4xi32>}>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<4xi32> -> !torch.vtensor<[4],si32>
+// CHECK: return %[[RET]] : !torch.vtensor<[4],si32>
+func.func @torch.vtensor.literal_resource_si32$basic() -> !torch.vtensor<[4],si32> {
+  %0 = torch.vtensor.literal(dense_resource<torch_resource_i32> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
+  return %0 : !torch.vtensor<[4],si32>
+}
+
+{-#
+  dialect_resources: {
+    builtin: {
+      torch_resource_i32: "0x08000000000000000a0000008000000068420000"
+    }
+  }
+#-}
+
+// -----
+
+// CHECK-LABEL: @torch.vtensor.literal_resource_si64$basic(
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<torch_resource_i64> : tensor<3xi64>}>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<3xi64> -> !torch.vtensor<[3],si64>
+// CHECK: return %[[RET]] : !torch.vtensor<[3],si64>
+func.func @torch.vtensor.literal_resource_si64$basic() -> !torch.vtensor<[3],si64> {
+  %0 = torch.vtensor.literal(dense_resource<torch_resource_i64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
+  return %0 : !torch.vtensor<[3],si64>
+}
+
+{-#
+  dialect_resources: {
+    builtin: {
+      torch_resource_i64: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
 // CHECK: %[[VAL_0:.*]] = torch.constant.none
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 0
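
As a reading aid for the resource payloads (assuming MLIR's standard blob hex encoding, where the leading word stores the blob's alignment and the remainder is the raw little-endian data): torch_resource_i32 decodes to alignment 8 followed by the values 0, 10, 128, and 17000, and torch_resource_i64 to alignment 8 followed by 1, 2, and 3. A hypothetical inline equivalent of the i32 constant, not part of the commit:

    // 0x08000000 (alignment 8), then 0x00000000, 0x0a000000, 0x80000000,
    // 0x68420000 read little-endian:
    %cst = "tosa.const"() <{values = dense<[0, 10, 128, 17000]> : tensor<4xi32>}> : () -> tensor<4xi32>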
