Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 86 additions & 36 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3647,23 +3647,87 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}

static void printInitializationList(OpAsmPrinter &parser,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "") {
assert(blocksArgs.size() == initializers.size() &&
"expected same length of arguments and initializers");
if (initializers.empty())
return;

parser << prefix << '(';
llvm::interleaveComma(
llvm::zip(blocksArgs, initializers), parser,
[&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
parser << ")";
}

// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
result.regions.reserve(2);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();

auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
// Create a i1 tensor type for the boolean condition.
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
if (parser.parseOperand(cond) ||
parser.resolveOperand(cond, i1Type, result.operands))

if (parser.parseOperand(cond))
return failure();
// Parse optional results type list.
if (parser.parseOptionalArrowTypeList(result.types))

SmallVector<OpAsmParser::Argument, 4> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

// Parse the optional block arguments
OptionalParseResult listResult =
parser.parseOptionalAssignmentList(regionArgs, operands);
if (listResult.has_value() && failed(listResult.value()))
return failure();

// Parse a colon.
if (failed(parser.parseColon()))
return parser.emitError(parser.getCurrentLocation(),
"expected type for condition operand");

// Parse the type of the condition operand
Type condType;
if (failed(parser.parseType(condType)))
return parser.emitError(parser.getCurrentLocation(),
"expected type for condition operand");

// Resolve operand with provided type
if (failed(parser.resolveOperand(cond, condType, result.operands)))
return failure();

// Parse optional block arg types
if (listResult.has_value()) {
FunctionType functionType;

if (failed(parser.parseType(functionType)))
return parser.emitError(parser.getCurrentLocation())
<< "expected list of types for block arguments "
<< "followed by arrow type and list of return types";

result.addTypes(functionType.getResults());

if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(parser.getCurrentLocation())
<< "expected as many input types as operands "
<< "(expected " << operands.size() << " got "
<< functionType.getNumInputs() << ")";
}

// Resolve input operands.
if (failed(parser.resolveOperands(operands, functionType.getInputs(),
parser.getCurrentLocation(),
result.operands)))
return failure();
} else {
// Parse optional results type list.
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
}

// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
Expand All @@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}

void IfOp::print(OpAsmPrinter &p) {
bool printBlockTerminators = false;

p << " " << getCondition();
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ")";
// Print yield explicitly if the op defines values.
printBlockTerminators = true;

printInitializationList(p, getThenGraph().front().getArguments(),
getInputList(), " ");
p << " : ";
p << getCondition().getType();

if (!getInputList().empty()) {
p << " (";
llvm::interleaveComma(getInputList().getTypes(), p);
p << ")";
}
p << ' ';
p.printRegion(getThenGraph(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
p.printArrowTypeList(getResultTypes());
p << " ";

p.printRegion(getThenGraph());

// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
p.printRegion(elseRegion);
}

p.printOptionalAttrDict((*this)->getAttrs());
Expand Down Expand Up @@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}

static void printInitializationList(OpAsmPrinter &parser,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "") {
assert(blocksArgs.size() == initializers.size() &&
"expected same length of arguments and initializers");
if (initializers.empty())
return;

parser << prefix << '(';
llvm::interleaveComma(
llvm::zip(blocksArgs, initializers), parser,
[&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
Expand Down
14 changes: 6 additions & 8 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
// %0 = tosa.cond_if %arg2 {
// tosa.yield %arg0
// %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
// ^bb0(%arg3, %arg4):
// tosa.yield %arg3
// } else {
// tosa.yield %arg1
// ^bb0(%arg3, %arg4):
// tosa.yield %arg4
// }
//
// Unfortunately, the simplified syntax does not encapsulate values
// used in then/else regions (see 'simplified' example above), so it
// must be rewritten to use the generic syntax in order to be conformant
// to the specification.

return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
// CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
// CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {

// CHECK: scf.yield [[ARG0]]
tosa.yield %arg0 : tensor<f32>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/availability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if profiles: [ ]
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/Tosa/controlflow.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: mlir-opt -split-input-file %s | FileCheck %s

// -----

func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
// CHECK: } else {
} else {
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
}
return %0 : tensor<f32>
}

// -----

func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
%0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
// CHECK: } else {
// CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
} else {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
}
return %0 : tensor<f32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/error_if_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:

func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<f32>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/invalid_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
// -----
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
// -----

func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%1 = tosa.cond_if %arg3 -> (tensor<f32>) {
%2 = tosa.cond_if %arg2 -> (tensor<f32>) {
%3 = tosa.cond_if %arg3 -> (tensor<f32>) {
%4 = tosa.cond_if %arg2 -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> {
%2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
%4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
%5 = tosa.cond_if %arg3 -> (tensor<f32>) {
%5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %res : tensor<f32>
} else {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
// CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
// CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1166,8 +1166,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
%b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>

// CHECK: tosa.cond_if
// CHECK: -> (tensor<f32>)
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
// CHECK: -> tensor<f32>
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
tosa.yield %a : tensor<f32>
} else {
tosa.yield %b : tensor<f32>
Expand All @@ -1180,8 +1180,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
// CHECK-LABEL: @if_test_dynamic
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
// CHECK: -> (tensor<?xf32>)
%0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
// CHECK: -> tensor<?xf32>
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
tosa.yield %arg0 : tensor<2xf32>
} else {
tosa.yield %arg1 : tensor<3xf32>
Expand All @@ -1194,8 +1194,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_unranked
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
// CHECK: -> (tensor<*xf32>)
%0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
// CHECK: -> tensor<*xf32>
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<3xf32>
Expand All @@ -1208,8 +1208,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_propagate
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
// CHECK: -> (tensor<f32>)
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
// CHECK: -> tensor<f32>
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
Expand Down
Loading