diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index ecddc9fe9a13d..b876880597dd4 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -396,7 +396,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [ }]; let arguments = (ins - Tosa_Tensor3D:$input, + Tosa_Tensor3D:$input_real, DefaultValuedOptionalAttr:$local_bound ); @@ -411,7 +411,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [ ]; let assemblyFormat = [{ - $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)` + $input_real attr-dict `:` `(` type($input_real) `)` `->` `(` type($output_real) `,` type($output_imag) `)` }]; let hasVerifier = 1; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index f7dd33c7e8b53..ccfe28270322c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2614,7 +2614,7 @@ struct RFFT2dConverter final : public OpRewritePattern { } auto loc = rfft2d.getLoc(); - auto input = rfft2d.getInput(); + auto input = rfft2d.getInputReal(); auto elementType = dyn_cast(cast(input.getType()).getElementType()); if (!elementType) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 7a991b3876f69..ffa540426cd71 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -805,7 +805,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, RFFT2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { - ShapeAdaptor inputShape(adaptor.getInput().getType()); + ShapeAdaptor inputShape(adaptor.getInputReal().getType()); if (!inputShape.hasRank()) return failure(); @@ -842,7 +842,8 @@ LogicalResult tosa::RFFT2dOp::verify() { if (failed(verifyCompatibleShapes(outputTypes))) return emitOpError("expected output shapes to match, got ") << outputTypes; - const auto inputType = llvm::dyn_cast(getInput().getType()); + const auto inputType = + llvm::dyn_cast(getInputReal().getType()); if (!inputType) return success(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 983062ffd7912..448a7e1982276 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -160,7 +160,7 @@ void ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { template <> void ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { - addValue(op.getInput()); + addValue(op.getInputReal()); addValue(op.getOutputReal()); addValue(op.getOutputImag()); }