diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index e0f2fd411bbe4..8362092021b0b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2471,12 +2471,12 @@ def Tosa_IfOp : Tosa_Op<"cond_if", }]; let arguments = (ins - Tosa_I1Tensor:$cond, - Variadic:$inputs + Tosa_I1Tensor:$condition, + Variadic:$input_list ); let results = (outs - Variadic:$output + Variadic:$output_list ); list availability = [ @@ -2485,8 +2485,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if", ]; let regions = (region - SizedRegion<1>:$then_branch, - SizedRegion<1>:$else_branch + SizedRegion<1>:$then_graph, + SizedRegion<1>:$else_graph ); let hasCustomAssemblyFormat = 1; @@ -2513,11 +2513,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ }]; let arguments = (ins - Variadic:$inputs + Variadic:$input_list ); let results = (outs - Variadic:$output + Variadic:$output_list ); list availability = [ @@ -2526,8 +2526,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ ]; let regions = (region - SizedRegion<1>:$cond, - SizedRegion<1>:$body + SizedRegion<1>:$cond_graph, + SizedRegion<1>:$body_graph ); let hasCustomAssemblyFormat = 1; diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp index 9139bf191fdf1..ef144fc7f0d54 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp @@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::IfOp op, PatternRewriter &rewriter) const final { auto condition = - rewriter.create(op.getLoc(), op.getCond()); + rewriter.create(op.getLoc(), op.getCondition()); auto newIf = rewriter.create(op.getLoc(), op.getResultTypes(), condition, true); - inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(), + inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(), rewriter); - inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(), + inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(), rewriter); rewriter.replaceOp(op, newIf.getResults()); @@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::WhileOp op, PatternRewriter &rewriter) const final { auto newWhile = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getInputs()); + op.getLoc(), op.getResultTypes(), op.getInputList()); rewriter.createBlock(&newWhile.getBefore()); rewriter.createBlock(&newWhile.getAfter()); - inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true); - inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false); + inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true); + inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false); rewriter.replaceOp(op, newWhile.getResults()); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 8841d53b6e64d..800968e6f4766 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -127,7 +127,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface { //===----------------------------------------------------------------------===// /// Returns the while loop body. -SmallVector tosa::WhileOp::getLoopRegions() { return {&getBody()}; } +SmallVector tosa::WhileOp::getLoopRegions() { + return {&getBodyGraph()}; +} //===----------------------------------------------------------------------===// // Tosa dialect initialization. @@ -2536,7 +2538,7 @@ LogicalResult WhileOp::inferReturnTypeComponents( WhileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector yieldOps; - for (auto &block : adaptor.getBody()) + for (auto &block : adaptor.getBodyGraph()) if (auto returnOp = dyn_cast(block.getTerminator())) yieldOps.push_back(returnOp); @@ -2616,19 +2618,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { void IfOp::print(OpAsmPrinter &p) { bool printBlockTerminators = false; - p << " " << getCond(); + p << " " << getCondition(); if (!getResults().empty()) { p << " -> (" << getResultTypes() << ")"; // Print yield explicitly if the op defines values. printBlockTerminators = true; } p << ' '; - p.printRegion(getThenBranch(), + p.printRegion(getThenGraph(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/printBlockTerminators); // Print the 'else' regions if it exists and has a block. - auto &elseRegion = getElseBranch(); + auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; p.printRegion(elseRegion, @@ -2726,14 +2728,15 @@ static void printInitializationList(OpAsmPrinter &parser, } void WhileOp::print(OpAsmPrinter &parser) { - printInitializationList(parser, getCond().front().getArguments(), getInputs(), - " "); + printInitializationList(parser, getCondGraph().front().getArguments(), + getInputList(), " "); parser << " : "; - parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes()); + parser.printFunctionalType(getInputList().getTypes(), + getResults().getTypes()); parser << ' '; - parser.printRegion(getCond(), /*printEntryBlockArgs=*/false); + parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false); parser << " do "; - parser.printRegion(getBody()); + parser.printRegion(getBodyGraph()); parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 96fb054d75b66..1060f520d2930 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -371,14 +371,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } } if (auto condIf = dyn_cast(op)) { - if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") || - !levelCheckListSize(op, condIf.getOutput().size(), "outputs")) { + if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") || + !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) { return false; } } if (auto w = dyn_cast(op)) { - if (!levelCheckListSize(op, w.getInputs().size(), "inputs") || - !levelCheckListSize(op, w.getOutput().size(), "outputs")) { + if (!levelCheckListSize(op, w.getInputList().size(), "inputs") || + !levelCheckListSize(op, w.getOutputList().size(), "outputs")) { return false; } } @@ -450,7 +450,7 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { auto op = tosaOp.getOperation(); // Only the condition input has rank limitation. - if (!levelCheckRank(op, tosaOp.getCond(), "operand", tosaLevel.MAX_RANK)) + if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK)) return false; return true;