Skip to content

Commit bbd46d0

Browse files
[TOSA] Fix empty-dim reductions (#4365)
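Teach the TorchToTosa reduction lowering that an explicit empty dim list means "reduce over all dims", and cast the reduced result to the requested output dtype when it differs from the type the reduction was computed in. Add MLIR and e2e regression cases and update the XFAIL sets.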
1 parent 72508c7 commit bbd46d0

File tree

5 files changed

+111
-3
lines changed

5 files changed

+111
-3
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,11 @@ class ConvertAtenMultipleDimsReductionOp
11121112
for (int64_t i = 0; i < inputRank; i++)
11131113
reduceDims.push_back(i);
11141114
}
1115+
// PyTorch treats an explicit empty list the same as "reduce all dims".
1116+
if (reduceDims.empty()) {
1117+
for (int64_t i = 0; i < inputRank; i++)
1118+
reduceDims.push_back(i);
1119+
}
11151120

11161121
int64_t N = reduceDims.size();
11171122
for (unsigned i = 0; i < N; i++) {

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,13 +785,23 @@ std::optional<Value> convertReduceOpCommon(
785785

786786
// Optionally squeeze out the reduced axes.
787787
if (!keep_dims) {
788+
auto squeezedType =
789+
RankedTensorType::get(output_shape, reduce_element_type);
788790
auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
789-
rewriter, op->getLoc(), output_type, val,
791+
rewriter, op->getLoc(), squeezedType, val,
790792
tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape));
791793
val = reshape_op.getResult();
792794
}
793795
}
794796

797+
// Ensure the result element type matches the expected output type.
798+
if (val.getType() != output_type) {
799+
auto casted = tosa::tosaCastTensorToType(rewriter, val, output_type);
800+
if (!casted)
801+
return std::nullopt;
802+
val = casted.value();
803+
}
804+
795805
return val;
796806
}
797807

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3435,6 +3435,8 @@
34353435
"ElementwiseClampMinModule_bfloat16",
34363436
"ElementwiseClampModule_bfloat16",
34373437
"ElementwiseReluModule_bfloat16",
3438+
# torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this '<class 'torch.TensorType'>'
3439+
"ReduceSumEmptyDimListInt8ToInt32Module_basic",
34383440
}
34393441

34403442
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3825,7 +3827,6 @@
38253827
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
38263828
"MaxPool3dWithIndicesStaticModule_basic",
38273829
"MaxPool3dSingleIntTupleDilationModule_basic",
3828-
"MeanDimEmptyDimModule_basic",
38293830
"MlGroupNormManualModule_basic",
38303831
"MlGroupNormModule_basic",
38313832
"MlLayerNormManualModule_basic",
@@ -3880,7 +3881,6 @@
38803881
"ReduceL3NormKeepDimComplexModule_basic",
38813882
"ReduceMaxAlongDimUnsignedInt_basic",
38823883
"ReduceMinAlongDimUnsignedInt_basic",
3883-
"ReduceSumDimIntListEmptyDimModule_basic",
38843884
"RollModule_basic",
38853885
"ScalarConstantTupleModule_basic",
38863886
"ScalarImplicitFloatModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,52 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
5858
# ==============================================================================
5959

6060

61+
class ReduceSumEmptyDimListInt8ToInt32Module(torch.nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
65+
@export
66+
@annotate_args(
67+
[
68+
None,
69+
([-1, -1, -1], torch.int8, True),
70+
]
71+
)
72+
def forward(self, a):
73+
return torch.sum(a, dim=[], dtype=torch.int32)
74+
75+
76+
@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8ToInt32Module())
77+
def ReduceSumEmptyDimListInt8ToInt32Module_basic(module, tu: TestUtils):
78+
module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
79+
80+
81+
# ==============================================================================
82+
83+
84+
class ReduceSumEmptyDimListInt8Module(torch.nn.Module):
85+
def __init__(self):
86+
super().__init__()
87+
88+
@export
89+
@annotate_args(
90+
[
91+
None,
92+
([-1, -1, -1], torch.int8, True),
93+
]
94+
)
95+
def forward(self, a):
96+
return torch.sum(a, dim=[])
97+
98+
99+
@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8Module())
100+
def ReduceSumEmptyDimListInt8Module_basic(module, tu: TestUtils):
101+
module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
102+
103+
104+
# ==============================================================================
105+
106+
61107
class ReduceSumElementTypeBoolModule(torch.nn.Module):
62108
def __init__(self):
63109
super().__init__()

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,53 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !
325325

326326
// -----
327327

328+
// CHECK-LABEL: func.func @test_reduce_sum_empty_dims$basic(
329+
// CHECK-SAME: %[[INPUT_F32:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
330+
// CHECK: %[[INPUT_F32_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_F32]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
331+
// CHECK: %[[NONE:.*]] = torch.constant.none
332+
// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
333+
// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[INPUT_F32_TENSOR]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
334+
// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32>
335+
// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32>
336+
// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape
337+
// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor<f32>
338+
// CHECK: %[[RESULT_F32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor<f32> -> !torch.vtensor<[],f32>
339+
// CHECK: return %[[RESULT_F32]] : !torch.vtensor<[],f32>
340+
// CHECK: }
341+
func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
342+
%dtype_none = torch.constant.none
343+
%keep_dims_false = torch.constant.bool false
344+
%all_dims_list = torch.prim.ListConstruct : () -> !torch.list<int>
345+
%sum_all_dims = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_none : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
346+
return %sum_all_dims : !torch.vtensor<[],f32>
347+
}
348+
349+
// -----
350+
351+
// CHECK-LABEL: func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(
352+
// CHECK-SAME: %[[INPUT_I8:.*]]: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> {
353+
// CHECK: %[[INPUT_I8_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_I8]] : !torch.vtensor<[2,3,4],si8> -> tensor<2x3x4xi8>
354+
// CHECK: %[[DTYPE_I32:.*]] = torch.constant.int 3
355+
// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
356+
// CHECK: %[[CAST_INPUT_TO_I32:.*]] = tosa.cast %[[INPUT_I8_TENSOR]] : (tensor<2x3x4xi8>) -> tensor<2x3x4xi32>
357+
// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[CAST_INPUT_TO_I32]] {axis = 0 : i32} : (tensor<2x3x4xi32>) -> tensor<1x3x4xi32>
358+
// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xi32>) -> tensor<1x1x4xi32>
359+
// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xi32>) -> tensor<1x1x1xi32>
360+
// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape
361+
// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xi32>, !tosa.shape<0>) -> tensor<i32>
362+
// CHECK: %[[RESULT_I32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor<i32> -> !torch.vtensor<[],si32>
363+
// CHECK: return %[[RESULT_I32]] : !torch.vtensor<[],si32>
364+
// CHECK: }
365+
func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(%arg0: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> {
366+
%dtype_i32 = torch.constant.int 3
367+
%keep_dims_false = torch.constant.bool false
368+
%all_dims_list = torch.prim.ListConstruct : () -> !torch.list<int>
369+
%sum_all_dims_to_i32 = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_i32 : !torch.vtensor<[2,3,4],si8>, !torch.list<int>, !torch.bool, !torch.int -> !torch.vtensor<[],si32>
370+
return %sum_all_dims_to_i32 : !torch.vtensor<[],si32>
371+
}
372+
373+
// -----
374+
328375
// CHECK-LABEL: func.func @test_linalg_vector_norm$basic(
329376
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
330377
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>

0 commit comments

Comments
 (0)