Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2471,12 +2471,12 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];

let arguments = (ins
Tosa_I1Tensor:$cond,
Variadic<Tosa_Tensor>:$inputs
Tosa_I1Tensor:$condition,
Variadic<Tosa_Tensor>:$input_list
);

let results = (outs
Variadic<Tosa_Tensor>:$output
Variadic<Tosa_Tensor>:$output_list
);

list<Availability> availability = [
Expand All @@ -2485,8 +2485,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
];

let regions = (region
SizedRegion<1>:$then_branch,
SizedRegion<1>:$else_branch
SizedRegion<1>:$then_graph,
SizedRegion<1>:$else_graph
);

let hasCustomAssemblyFormat = 1;
Expand All @@ -2513,11 +2513,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
}];

let arguments = (ins
Variadic<Tosa_Tensor>:$inputs
Variadic<Tosa_Tensor>:$input_list
);

let results = (outs
Variadic<Tosa_Tensor>:$output
Variadic<Tosa_Tensor>:$output_list
);

list<Availability> availability = [
Expand All @@ -2526,8 +2526,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
];

let regions = (region
SizedRegion<1>:$cond,
SizedRegion<1>:$body
SizedRegion<1>:$cond_graph,
SizedRegion<1>:$body_graph
);

let hasCustomAssemblyFormat = 1;
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
auto condition =
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
condition, true);

inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
rewriter);
inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
rewriter);

rewriter.replaceOp(op, newIf.getResults());
Expand Down Expand Up @@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
auto newWhile = rewriter.create<scf::WhileOp>(
op.getLoc(), op.getResultTypes(), op.getInputs());
op.getLoc(), op.getResultTypes(), op.getInputList());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());

inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);

rewriter.replaceOp(op, newWhile.getResults());

Expand Down
23 changes: 13 additions & 10 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
return {&getBodyGraph()};
}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
Expand Down Expand Up @@ -2536,7 +2538,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (auto &block : adaptor.getBody())
for (auto &block : adaptor.getBodyGraph())
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);

Expand Down Expand Up @@ -2616,19 +2618,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
void IfOp::print(OpAsmPrinter &p) {
bool printBlockTerminators = false;

p << " " << getCond();
p << " " << getCondition();
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ")";
// Print yield explicitly if the op defines values.
printBlockTerminators = true;
}
p << ' ';
p.printRegion(getThenBranch(),
p.printRegion(getThenGraph(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);

// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseBranch();
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
Expand Down Expand Up @@ -2726,14 +2728,15 @@ static void printInitializationList(OpAsmPrinter &parser,
}

void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCond().front().getArguments(), getInputs(),
" ");
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
parser << " : ";
parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
parser.printFunctionalType(getInputList().getTypes(),
getResults().getTypes());
parser << ' ';
parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
parser << " do ";
parser.printRegion(getBody());
parser.printRegion(getBodyGraph());
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
!levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
return false;
}
}
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
!levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
return false;
}
}
Expand Down Expand Up @@ -450,7 +450,7 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();

// Only the condition input has rank limitation.
if (!levelCheckRank(op, tosaOp.getCond(), "operand", tosaLevel.MAX_RANK))
if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
return false;

return true;
Expand Down