@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
562562
563563 bool CheckVariable (Operation *op);
564564 bool CheckVariableReadOrWrite (Operation *op);
565- bool isValidElementType (Type type);
565+ bool isValidElementType (Type type, const bool allowUnsigned = false );
566566
567567 SmallVector<
568568 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
11761176 return success ();
11771177}
11781178
1179- bool TosaValidation::isValidElementType (Type type) {
1179+ bool TosaValidation::isValidElementType (Type type, const bool allowUnsigned ) {
11801180 if (isa<FloatType>(type)) {
11811181 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
11821182 Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
11911191 case 48 :
11921192 return true ;
11931193 }
1194+ } else if (allowUnsigned && intTy.isUnsigned ()) {
1195+ switch (intTy.getWidth ()) {
1196+ case 8 :
1197+ case 16 :
1198+ case 32 :
1199+ return true ;
1200+ }
11941201 }
11951202 } else if (mlir::isa<tosa::shapeType>(type)) {
11961203 return true ;
@@ -1209,19 +1216,23 @@ void TosaValidation::runOnOperation() {
12091216 if (op->getDialect () != tosaDialect)
12101217 return ;
12111218
1212- // perform valid element type check at the beginning to
1213- // protect rest of code against quantized element types
1219+ // validate operator element types:
1220+ // - rescale operator is allowed to have ui8/ui16/ui32
1221+ // operands/results
1222+ // - perform valid element type check at the beginning to
1223+ // protect rest of code against quantized element types
1224+ const bool opIsRescale = isa<tosa::RescaleOp>(op);
12141225 for (Value operand : op->getOperands ()) {
12151226 auto elementTy = getElementTypeOrSelf (operand);
1216- if (!isValidElementType (elementTy)) {
1227+ if (!isValidElementType (elementTy, opIsRescale )) {
12171228 op->emitOpError () << " is not profile-aligned: element type "
12181229 << elementTy << " is not legal" ;
12191230 return signalPassFailure ();
12201231 }
12211232 }
12221233 for (Type resultTy : op->getResultTypes ()) {
12231234 auto elementTy = getElementTypeOrSelf (resultTy);
1224- if (!isValidElementType (elementTy)) {
1235+ if (!isValidElementType (elementTy, opIsRescale )) {
12251236 op->emitOpError () << " is not profile-aligned: element type "
12261237 << elementTy << " is not legal" ;
12271238 return signalPassFailure ();
0 commit comments