diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index c0a25a56dbe2a..6e1e3343ac169 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -136,14 +136,14 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::MulOp if (isa(op)) { - auto shift_val = cast(op).getShift(); - DenseElementsAttr shift_elem; - if (!shift_val.getImpl() || - !matchPattern(shift_val, m_Constant(&shift_elem))) { + auto shiftVal = cast(op).getShift(); + DenseElementsAttr shiftElem; + if (!matchPattern(shiftVal, m_Constant(&shiftElem))) { (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); + return nullptr; } - int32_t shift = shift_elem.getValues()[0].getInt(); + int32_t shift = shiftElem.getValues()[0].getInt(); if (isa(elementTy)) { if (shift != 0) { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index d00846a4c3e02..69d8471df8032 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -73,3 +73,11 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> { + // expected-error@+1 {{failed to legalize operation 'tosa.mul'}} + %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +}