Skip to content

Commit fbc1038

Browse files
authored
[mlir][TosaToLinalg] Only support ranked tensor for reduce and gather (#131805)
This PR adds checks for ranked tensors in converter of reduce and gather to prevent crash. Fixes #131087.
1 parent d2c41fb commit fbc1038

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,8 +1176,11 @@ template <typename OpTy>
11761176
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11771177
PatternRewriter &rewriter) {
11781178
auto loc = op->getLoc();
1179-
auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1180-
auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1179+
auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1180+
auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1181+
if (!inputTy || !resultTy)
1182+
return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1183+
11811184
auto elementTy = resultTy.getElementType();
11821185
Value input = op->getOperand(0);
11831186

@@ -2380,11 +2383,9 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
23802383
auto input = adaptor.getOperands()[0];
23812384
auto indices = adaptor.getOperands()[1];
23822385

2383-
auto valuesTy =
2384-
dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2385-
auto resultTy = cast<ShapedType>(op.getType());
2386-
2387-
if (!valuesTy)
2386+
auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2387+
auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2388+
if (!valuesTy || !resultTy)
23882389
return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
23892390

23902391
auto dynamicDims = inferDynamicDimsForGather(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,19 @@ func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!
5757
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
5858
return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
5959
}
60+
61+
// -----
62+
63+
func.func @unranked_reduce(%arg0: tensor<*xf32>) -> tensor<*xf32> {
64+
// expected-error@+1 {{failed to legalize operation 'tosa.reduce_sum'}}
65+
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
66+
return %0 : tensor<*xf32>
67+
}
68+
69+
// -----
70+
71+
func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<*xf32> {
72+
// expected-error@+1 {{failed to legalize operation 'tosa.gather'}}
73+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
74+
return %0 : tensor<*xf32>
75+
}

0 commit comments

Comments
 (0)