Skip to content

Commit 4127458

Browse files
authored
[mlir][tosa] Fix check for isolated regions in tosa.cond_if (#143772)
This commit fixes a check in the validation pass which intended to validate whether a `tosa.cond_if` operation was conformant to the specification. The specification requires all values used in the then/else regions are explicitly declared within the regions. This change checks that these regions are 'isolated from above', to ensure this requirement is true.
1 parent 20d8398 commit 4127458

File tree

2 files changed

+97
-36
lines changed

2 files changed

+97
-36
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,32 +1193,55 @@ bool checkErrorIfPad(Operation *op) {
11931193
return true;
11941194
}
11951195

1196-
// Returns true if the operation takes no input operands, excluding attributes.
1197-
static bool isNullaryOperation(Operation *op) {
1198-
if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
1199-
isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
1200-
return true;
1201-
return false;
1196+
static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1197+
return llvm::all_of(op->getOperands(), [&](auto operand) {
1198+
Region *operandRegion = operand.getParentRegion();
1199+
return operandRegion && region->isAncestor(operandRegion);
1200+
});
12021201
}
12031202

12041203
bool checkErrorIfCondIf(Operation *op) {
12051204
auto ifOp = dyn_cast<tosa::IfOp>(op);
12061205
if (!ifOp)
12071206
return true;
12081207

1209-
// Whether the types and shapes of operands between the input/output list and
1210-
// internal regions are validated by the operation verifier. However, with
1211-
// support for the simplified form - where redundant operand notations are
1212-
// omitted - is not conformant to the specification. According to the
1213-
// specification, all operands passed into an operation must be explicitly
1214-
// declared at each operation's structure. This code section verify that the
1215-
// operation's form complies with this requirement.
1208+
// Currently the dialect supports declaring cond_if operations that
1209+
// have then/else regions that reference values from outside these
1210+
// regions. According to the specification, all values used by the
1211+
// then/else regions must be explicitly declared within the regions.
1212+
// Therefore we must check that the then/else regions are
1213+
// "isolated from above", in order to be conformant to the
1214+
// specification.
1215+
//
1216+
// Note: the dialect currently supports two styles of syntax for
1217+
// declaring "cond_if" operations. We'll refer to these as follows:
1218+
//
1219+
// Generic:
1220+
// %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1221+
// ^bb0(%arg3, %arg4):
1222+
// tosa.yield %arg3
1223+
// }, {
1224+
// ^bb0(%arg3, %arg4):
1225+
// tosa.yield %arg4
1226+
// })
1227+
//
1228+
// Simplified:
1229+
// %0 = tosa.cond_if %arg2 {
1230+
// tosa.yield %arg0
1231+
// } else {
1232+
// tosa.yield %arg1
1233+
// }
1234+
//
1235+
// Unfortunately, the simplified syntax does not encapsulate values
1236+
// used in then/else regions (see 'simplified' example above), so it
1237+
// must be rewritten to use the generic syntax in order to be conformant
1238+
// to the specification.
12161239

12171240
// Returns true if the region uses no external input operands.
1218-
auto isNullaryRegion = [](Region &region) -> bool {
1241+
auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
12191242
bool noLiveInValue = true;
1220-
region.walk([&noLiveInValue](Operation *op) {
1221-
if (!isNullaryOperation(op)) {
1243+
regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
1244+
if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
12221245
noLiveInValue = false;
12231246
return WalkResult::interrupt();
12241247
}
@@ -1227,21 +1250,18 @@ bool checkErrorIfCondIf(Operation *op) {
12271250
return noLiveInValue;
12281251
};
12291252

1230-
mlir::Region &thenGraph = ifOp.getThenGraph();
1231-
mlir::Region &elseGraph = ifOp.getElseGraph();
1232-
bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
1233-
bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
1234-
bool isInputListEmpty = ifOp.getInputList().size() == 0;
1235-
1236-
if ((isInputListEmpty != isThenGraphNullaryRegion) ||
1237-
(isInputListEmpty != isElseGraphNullaryRegion)) {
1253+
auto checkIsolatedRegion = [&](Region &regionToCheck,
1254+
StringRef regionName) -> LogicalResult {
1255+
if (isIsolatedRegion(regionToCheck))
1256+
return success();
12381257
op->emitOpError()
1239-
<< "the current simplified form is not strictly conformant to the "
1240-
"spec, please use the generic format\n";
1241-
return false;
1242-
}
1258+
<< "is not conformant to the TOSA specification. It requires the '"
1259+
<< regionName << "' region is isolated from above.\n";
1260+
return failure();
1261+
};
12431262

1244-
return true;
1263+
return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
1264+
failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
12451265
}
12461266

12471267
bool checkErrorIfScatter(Operation *op) {

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,56 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32>
227227
}
228228

229229
// -----
230-
// CHECK-LABEL: cond_if_simplified_form
231-
func.func @test_cond_if_simplified_form(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
232-
// expected-error@+1 {{'tosa.cond_if' op the current simplified form is not strictly conformant to the spec, please use the generic format}}
230+
231+
func.func @test_cond_if_then_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
232+
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
233+
%0 = "tosa.cond_if"(%arg2, %arg1) ({
234+
^bb0(%arg3: tensor<f32>):
235+
tosa.yield %arg1 : tensor<f32>
236+
}, {
237+
^bb0(%arg3: tensor<f32>):
238+
tosa.yield %arg3 : tensor<f32>
239+
}) : (tensor<i1>, tensor<f32>) -> tensor<f32>
240+
return %0 : tensor<f32>
241+
}
242+
243+
// -----
244+
245+
func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
246+
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'else' region is isolated from above.}}
247+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
248+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
249+
tosa.yield %arg3 : tensor<f32>
250+
}, {
251+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
252+
%add = tosa.add %arg0, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
253+
tosa.yield %add : tensor<f32>
254+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
255+
return %0 : tensor<f32>
256+
}
257+
258+
// -----
259+
260+
func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
261+
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
233262
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
234-
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
235-
tosa.yield %1 : tensor<f32>
263+
tosa.yield %arg0 : tensor<f32>
236264
} else {
237-
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
238-
tosa.yield %1 : tensor<f32>
265+
tosa.yield %arg1 : tensor<f32>
239266
}
240267
return %0 : tensor<f32>
241268
}
269+
270+
// -----
271+
272+
// Check isolated cond_if's are valid
273+
func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
274+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
275+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
276+
tosa.yield %arg3 : tensor<f32>
277+
}, {
278+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
279+
tosa.yield %arg4 : tensor<f32>
280+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
281+
return %0 : tensor<f32>
282+
}

0 commit comments

Comments
 (0)