@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
104104 }
105105
106106 LogicalResult applyLevelCheck (Operation *op);
107+ LogicalResult applyAttributeCheck (Operation *op);
107108
108109 // check variable read/write data types against variable declarations
109110 LogicalResult applyVariableCheck (Operation *op);
@@ -386,6 +387,25 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386387 return true ;
387388 }
388389
390+ bool attributeCheckRescale (Operation *op) {
391+ if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
392+ if (rescale.getRoundingMode () == " DOUBLE_ROUND" &&
393+ !targetEnv.allows (Extension::doubleround)) {
394+ op->emitOpError ()
395+ << " failed attribute check: rounding_mode = DOUBLE_ROUND "
396+ << " requires extension [doubleround]" ;
397+ return false ;
398+ } else if (rescale.getRoundingMode () == " INEXACT_ROUND" &&
399+ !targetEnv.allows (Extension::inexactround)) {
400+ op->emitOpError ()
401+ << " failed attribute check: rounding_mode = INEXACT_ROUND "
402+ << " requires extension [inexactround]" ;
403+ return false ;
404+ }
405+ }
406+ return true ;
407+ }
408+
389409 // configure profile and level values from pass options profileName and
390410 // levelName
391411 void configLevelAndProfile () {
@@ -415,7 +435,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
415435 } else {
416436 llvm::errs () << " unknown TOSA extension name passed in: " << ext
417437 << " , supported extension are int16, int4, bf16, "
418- << " fp8e4m3, fp8e5m2, fft, variable and controlflow\n " ;
438+ << " fp8e4m3, fp8e5m2, fft, variable, controlflow, "
439+ << " doubleround and inexactround\n " ;
419440 return signalPassFailure ();
420441 }
421442 }
@@ -642,6 +663,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
642663 return success ();
643664}
644665
666+ LogicalResult TosaValidation::applyAttributeCheck (Operation *op) {
667+ if (!attributeCheckRescale (op))
668+ return failure ();
669+ return success ();
670+ }
671+
645672inline bool CompatibleTypes (const mlir::Type &type,
646673 const mlir::Type &declaredType) {
647674 // for now, simply use type equality comparison
@@ -936,6 +963,10 @@ void TosaValidation::runOnOperation() {
936963 if (failed (applyLevelCheck (op)))
937964 signalPassFailure ();
938965
966+ // check additional attribute restrictions
967+ if (failed (applyAttributeCheck (op)))
968+ signalPassFailure ();
969+
939970 // do variable type checks
940971 if (failed (applyVariableCheck (op)))
941972 signalPassFailure ();
0 commit comments