@@ -572,6 +572,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
572572 return success ();
573573}
574574
575+ // Verify whether same type and shape of the given two types.
576+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, Type type1,
577+ StringRef name1, Type type2,
578+ StringRef name2) {
579+ auto shapeType1 = dyn_cast<ShapedType>(type1);
580+ auto shapeType2 = dyn_cast<ShapedType>(type2);
581+ if (!shapeType1 || !shapeType2)
582+ return failure ();
583+
584+ auto elemType1 = shapeType1.getElementType ();
585+ auto elemType2 = shapeType2.getElementType ();
586+ if (elemType1 != elemType2)
587+ return op->emitOpError ()
588+ << " require same element type for " << name1 << " (" << elemType1
589+ << " ) and " << name2 << " (" << elemType2 << " )" ;
590+
591+ if (failed (verifyCompatibleShape (type1, type2)))
592+ return op->emitOpError ()
593+ << " require same shapes for " << name1 << " (" << type1 << " ) and "
594+ << name2 << " (" << type2 << " )" ;
595+
596+ return success ();
597+ }
598+
599+ // Verify whether same length, type, and shape of the given two tensor lists.
600+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, ValueRange list1,
601+ StringRef name1,
602+ ValueRange list2,
603+ StringRef name2) {
604+ if (list1.size () != list2.size ())
605+ return op->emitOpError ()
606+ << " require same number of values in " << name1 << " ("
607+ << list1.size () << " ) and " << name2 << " (" << list2.size () << " )" ;
608+
609+ for (auto [type1, type2] :
610+ llvm::zip_equal (list1.getTypes (), list2.getTypes ())) {
611+ if (errorIfTypeOrShapeMismatch (op, type1, name1, type2, name2).failed ())
612+ return failure ();
613+ }
614+
615+ return success ();
616+ }
617+
618+ static inline LogicalResult errorIfShapeNotSizeOne (Operation *op, Type type) {
619+ ShapeAdaptor shapeAdaptor (type);
620+ if (!shapeAdaptor.hasRank () || !shapeAdaptor.hasStaticShape ())
621+ return success ();
622+
623+ return shapeAdaptor.getNumElements () == 1 ? success () : failure ();
624+ }
625+
575626// verify that inType and outType have same element types
576627template <typename T>
577628static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -3397,6 +3448,84 @@ void IfOp::print(OpAsmPrinter &p) {
33973448 p.printOptionalAttrDict ((*this )->getAttrs ());
33983449}
33993450
3451+ LogicalResult IfOp::verify () {
3452+ if (errorIfTypeOrShapeMismatch (*this , getThenGraph ().front ().getArguments (),
3453+ " 'then_graph' arguments" , getInputList (),
3454+ " 'input_list'" )
3455+ .failed ())
3456+ return failure ();
3457+
3458+ if (errorIfTypeOrShapeMismatch (*this , getElseGraph ().front ().getArguments (),
3459+ " 'else_graph' arguments" , getInputList (),
3460+ " 'input_list'" )
3461+ .failed ())
3462+ return failure ();
3463+
3464+ auto thenYield = cast<tosa::YieldOp>(getThenGraph ().front ().getTerminator ());
3465+ if (errorIfTypeOrShapeMismatch (*this , thenYield.getInputs (),
3466+ " 'then_graph' results" , getOutputList (),
3467+ " 'output_list'" )
3468+ .failed ())
3469+ return failure ();
3470+
3471+ auto elseYield = cast<tosa::YieldOp>(getElseGraph ().front ().getTerminator ());
3472+ if (errorIfTypeOrShapeMismatch (*this , elseYield.getInputs (),
3473+ " 'else_graph' results" , getOutputList (),
3474+ " 'output_list'" )
3475+ .failed ())
3476+ return failure ();
3477+
3478+ auto condType = getCondition ().getType ();
3479+ if (errorIfShapeNotSizeOne (*this , condType).failed ())
3480+ return emitOpError () << " 'condition' must be a size 1 tensor, got "
3481+ << condType;
3482+
3483+ return success ();
3484+ }
3485+
3486+ LogicalResult WhileOp::verify () {
3487+ if (errorIfTypeOrShapeMismatch (*this , getInputList (), " 'input_list'" ,
3488+ getOutputList (), " 'output_list'" )
3489+ .failed ())
3490+ return failure ();
3491+
3492+ if (errorIfTypeOrShapeMismatch (*this , getCondGraph ().front ().getArguments (),
3493+ " 'cond_graph' arguments" , getInputList (),
3494+ " 'input_list'" )
3495+ .failed ())
3496+ return failure ();
3497+
3498+ if (errorIfTypeOrShapeMismatch (*this , getBodyGraph ().front ().getArguments (),
3499+ " 'body_graph' arguments" , getInputList (),
3500+ " 'input_list'" )
3501+ .failed ())
3502+ return failure ();
3503+
3504+ auto bodyYield = cast<tosa::YieldOp>(getBodyGraph ().front ().getTerminator ());
3505+ if (errorIfTypeOrShapeMismatch (*this , bodyYield.getInputs (),
3506+ " 'body_graph' results" , getInputList (),
3507+ " 'input_list'" )
3508+ .failed ())
3509+ return failure ();
3510+
3511+ // Condition block output must be a single element tensor with a single bool
3512+ // value.
3513+ auto condYield = cast<tosa::YieldOp>(getCondGraph ().front ().getTerminator ());
3514+ if (condYield.getInputs ().size () != 1 )
3515+ return emitOpError () << " require 'cond_graph' only have one result" ;
3516+
3517+ auto condOutType = condYield.getInputs ()[0 ].getType ();
3518+ if (errorIfShapeNotSizeOne (*this , condOutType).failed ())
3519+ return emitOpError () << " 'cond_graph' result must be a size 1 tensor, got "
3520+ << condOutType;
3521+
3522+ if (!getElementTypeOrSelf (condOutType).isInteger (1 ))
3523+ return emitOpError () << " 'cond_graph' result must be a boolean tensor, got "
3524+ << condOutType;
3525+
3526+ return success ();
3527+ }
3528+
34003529LogicalResult ReverseOp::verify () {
34013530 if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
34023531 /* outType = */ getOutput ().getType ())
0 commit comments