Skip to content

Commit 2d1fae0

Browse files
committed
[mlir][tosa] Improve invalid operator data types error message
The error message on invalid operator data types in the validation pass was not very clear. This commit improves the error message as follows: Current: ``` 'tosa.add' op illegal: operand/result data types not supported ``` Improved: ``` 'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification. ``` Change-Id: I844bdc27f16100b1434c85fa9b79f756ae519c8c
1 parent 72b2219 commit 2d1fae0

File tree

4 files changed

+68
-7
lines changed

4 files changed

+68
-7
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ class TosaProfileCompliance {
164164
SmallVector<StringRef>
165165
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
166166

167+
static llvm::SmallString<7> stringifyTypeInfo(const TypeInfo &typeInfo);
168+
167169
private:
168170
template <typename T>
169171
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,52 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
485485
CheckCondition condition = CheckCondition::invalid;
486486
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
487487
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
488+
488489
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
489-
!maybeProfDef.value().size() && !maybeExtDef.value().size())
490+
!maybeProfDef.value().size() && !maybeExtDef.value().size()) {
491+
std::string message;
492+
llvm::raw_string_ostream os(message);
493+
os << "illegal: operation operand/result data types did not align with any "
494+
"profile or extension, got (";
495+
496+
ProfileInfoDepot depot(op);
497+
SmallVector<TypeInfo> current = depot.getInfo();
498+
for (const auto &typeInfo : llvm::drop_end(current))
499+
os << stringifyTypeInfo(typeInfo) << ",";
500+
os << stringifyTypeInfo(current.back()) << ")";
501+
502+
// avoid polluting the error message output by outputting only
503+
// the best match
504+
const std::string opName = op->getName().getStringRef().str();
505+
int maxMatches = -1;
506+
SmallVector<TypeInfo> bestTypeInfo;
507+
const auto searchBestMatch = [&](auto map) {
508+
for (const auto &complianceInfos : map[opName]) {
509+
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
510+
const int matches = llvm::count_if(
511+
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
512+
return isSameTypeInfo(std::get<0>(zipType),
513+
std::get<1>(zipType));
514+
});
515+
if (matches > maxMatches) {
516+
maxMatches = matches;
517+
bestTypeInfo = typeInfos;
518+
}
519+
}
520+
}
521+
};
522+
searchBestMatch(getProfileComplianceMap<Profile>());
523+
searchBestMatch(getProfileComplianceMap<Extension>());
524+
525+
os << ", did you mean (";
526+
for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
527+
os << stringifyTypeInfo(typeInfo) << ",";
528+
os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
529+
os << "Otherwise, please refer to the 'supported data types' for '"
530+
<< opName << "' in the specification.";
531+
op->emitOpError(message);
490532
return failure();
533+
}
491534

492535
return success();
493536
}
@@ -562,3 +605,21 @@ SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
562605

563606
return debugStrings;
564607
}
608+
609+
llvm::SmallString<7>
610+
TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
611+
if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
612+
return {"i" + llvm::utostr(typeInfo.bitWidth)};
613+
} else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
614+
return {"f16"};
615+
} else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
616+
return {"f32"};
617+
} else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
618+
return {"bf16"};
619+
} else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
620+
return {"fp8e4m3"};
621+
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
622+
return {"fp8e5m2"};
623+
}
624+
llvm_unreachable("unknown type");
625+
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,10 +1248,8 @@ void TosaValidation::runOnOperation() {
12481248
return signalPassFailure();
12491249

12501250
if (!allowInvalidOpDatatypeCombinations &&
1251-
failed(profileComp.checkInvalid(op))) {
1252-
op->emitOpError("illegal: operand/result data types not supported");
1251+
failed(profileComp.checkInvalid(op)))
12531252
return signalPassFailure();
1254-
}
12551253

12561254
// Some uses of TOSA rely on the constant operands of particular
12571255
// operations.

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2:
3535

3636
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
3737
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
38-
// expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
38+
// expected-error@+1 {{'tosa.conv2d' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i8,i8,i8,i32,i8), did you mean (i8,i8,i32,i8,i8,i32,i32)?}}
3939
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
4040
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
4141
return %0 : tensor<1x27x27x16xi8>
@@ -1888,7 +1888,7 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
18881888

18891889
// CHECK-LABEL: test_add_i1
18901890
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
1891-
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
1891+
// expected-error@+1 {{'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification.}}
18921892
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
18931893
return %0 : tensor<13x21x3xi1>
18941894
}
@@ -1897,7 +1897,7 @@ func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) ->
18971897

18981898
// CHECK-LABEL: test_mul_out_i16
18991899
func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
1900-
// expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
1900+
// expected-error@+1 {{'tosa.mul' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i16), did you mean (i8,i8,i32)?}}
19011901
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
19021902
return %0 : tensor<13x21x3xi16>
19031903
}

0 commit comments

Comments
 (0)