@@ -562,6 +562,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
562562 return success ();
563563}
564564
565+ // Verify whether same type and shape of the given two types.
566+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, Type type1,
567+ StringRef name1, Type type2,
568+ StringRef name2) {
569+ auto shapeType1 = dyn_cast<ShapedType>(type1);
570+ auto shapeType2 = dyn_cast<ShapedType>(type2);
571+ if (!shapeType1 || !shapeType2)
572+ return failure ();
573+
574+ auto elemType1 = shapeType1.getElementType ();
575+ auto elemType2 = shapeType2.getElementType ();
576+ if (elemType1 != elemType2)
577+ return op->emitOpError ()
578+ << " require same element type for " << name1 << " (" << elemType1
579+ << " ) and " << name2 << " (" << elemType2 << " )" ;
580+
581+ if (failed (verifyCompatibleShape (type1, type2)))
582+ return op->emitOpError ()
583+ << " require same shapes for " << name1 << " (" << type1 << " ) and "
584+ << name2 << " (" << type2 << " )" ;
585+
586+ return success ();
587+ }
588+
589+ // Verify whether same length, type, and shape of the given two tensor lists.
590+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, ValueRange list1,
591+ StringRef name1,
592+ ValueRange list2,
593+ StringRef name2) {
594+ if (list1.size () != list2.size ())
595+ return op->emitOpError ()
596+ << " require same number of values in " << name1 << " ("
597+ << list1.size () << " ) and " << name2 << " (" << list2.size () << " )" ;
598+
599+ for (auto [type1, type2] :
600+ llvm::zip_equal (list1.getTypes (), list2.getTypes ())) {
601+ if (errorIfTypeOrShapeMismatch (op, type1, name1, type2, name2).failed ())
602+ return failure ();
603+ }
604+
605+ return success ();
606+ }
607+
608+ static inline LogicalResult errorIfShapeNotSizeOne (Operation *op, Type type) {
609+ ShapeAdaptor shapeAdaptor (type);
610+ if (!shapeAdaptor.hasRank () || !shapeAdaptor.hasStaticShape ())
611+ return success ();
612+
613+ return shapeAdaptor.getNumElements () == 1 ? success () : failure ();
614+ }
615+
565616// verify that inType and outType have same element types
566617template <typename T>
567618static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -3437,6 +3488,84 @@ void IfOp::print(OpAsmPrinter &p) {
34373488 p.printOptionalAttrDict ((*this )->getAttrs ());
34383489}
34393490
3491+ LogicalResult IfOp::verify () {
3492+ if (errorIfTypeOrShapeMismatch (*this , getThenGraph ().front ().getArguments (),
3493+ " 'then_graph' arguments" , getInputList (),
3494+ " 'input_list'" )
3495+ .failed ())
3496+ return failure ();
3497+
3498+ if (errorIfTypeOrShapeMismatch (*this , getElseGraph ().front ().getArguments (),
3499+ " 'else_graph' arguments" , getInputList (),
3500+ " 'input_list'" )
3501+ .failed ())
3502+ return failure ();
3503+
3504+ auto thenYield = cast<tosa::YieldOp>(getThenGraph ().front ().getTerminator ());
3505+ if (errorIfTypeOrShapeMismatch (*this , thenYield.getInputs (),
3506+ " 'then_graph' results" , getOutputList (),
3507+ " 'output_list'" )
3508+ .failed ())
3509+ return failure ();
3510+
3511+ auto elseYield = cast<tosa::YieldOp>(getElseGraph ().front ().getTerminator ());
3512+ if (errorIfTypeOrShapeMismatch (*this , elseYield.getInputs (),
3513+ " 'else_graph' results" , getOutputList (),
3514+ " 'output_list'" )
3515+ .failed ())
3516+ return failure ();
3517+
3518+ auto condType = getCondition ().getType ();
3519+ if (errorIfShapeNotSizeOne (*this , condType).failed ())
3520+ return emitOpError () << " 'condition' must be a size 1 tensor, got "
3521+ << condType;
3522+
3523+ return success ();
3524+ }
3525+
3526+ LogicalResult WhileOp::verify () {
3527+ if (errorIfTypeOrShapeMismatch (*this , getInputList (), " 'input_list'" ,
3528+ getOutputList (), " 'output_list'" )
3529+ .failed ())
3530+ return failure ();
3531+
3532+ if (errorIfTypeOrShapeMismatch (*this , getCondGraph ().front ().getArguments (),
3533+ " 'cond_graph' arguments" , getInputList (),
3534+ " 'input_list'" )
3535+ .failed ())
3536+ return failure ();
3537+
3538+ if (errorIfTypeOrShapeMismatch (*this , getBodyGraph ().front ().getArguments (),
3539+ " 'body_graph' arguments" , getInputList (),
3540+ " 'input_list'" )
3541+ .failed ())
3542+ return failure ();
3543+
3544+ auto bodyYield = cast<tosa::YieldOp>(getBodyGraph ().front ().getTerminator ());
3545+ if (errorIfTypeOrShapeMismatch (*this , bodyYield.getInputs (),
3546+ " 'body_graph' results" , getInputList (),
3547+ " 'input_list'" )
3548+ .failed ())
3549+ return failure ();
3550+
3551+ // Condition block output must be a single element tensor with a single bool
3552+ // value.
3553+ auto condYield = cast<tosa::YieldOp>(getCondGraph ().front ().getTerminator ());
3554+ if (condYield.getInputs ().size () != 1 )
3555+ return emitOpError () << " require 'cond_graph' only have one result" ;
3556+
3557+ auto condOutType = condYield.getInputs ()[0 ].getType ();
3558+ if (errorIfShapeNotSizeOne (*this , condOutType).failed ())
3559+ return emitOpError () << " 'cond_graph' result must be a size 1 tensor, got "
3560+ << condOutType;
3561+
3562+ if (!getElementTypeOrSelf (condOutType).isInteger (1 ))
3563+ return emitOpError () << " 'cond_graph' result must be a boolean tensor, got "
3564+ << condOutType;
3565+
3566+ return success ();
3567+ }
3568+
34403569LogicalResult ReverseOp::verify () {
34413570 if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
34423571 /* outType = */ getOutput ().getType ())
0 commit comments