@@ -44,50 +44,67 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
4444
4545LogicalResult transform::smt::ConstrainParamsOp::verify () {
4646 auto yieldTerminator =
47- llvm::dyn_cast_if_present <mlir::smt::YieldOp>(getRegion ().front ().back ());
47+ dyn_cast <mlir::smt::YieldOp>(getRegion ().front ().back ());
4848 if (!yieldTerminator)
4949 return emitOpError () << " expected '"
5050 << mlir::smt::YieldOp::getOperationName ()
5151 << " ' as terminator" ;
5252
53+ auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
54+ Type paramType, StringRef paramDesc,
55+ auto *atOp) -> InFlightDiagnostic {
56+ if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
57+ smtType))
58+ return atOp->emitOpError () << " the type of " << smtDesc << " #" << idx
59+ << " is expected to be either a !smt.bool, a "
60+ " !smt.int, or a !smt.bv" ;
61+
62+ assert (isa<TransformParamTypeInterface>(paramType) &&
63+ " ODS specifies params' type should implement param interface" );
64+ if (isa<transform::AnyParamType>(paramType))
65+ return {}; // No further checks can be done.
66+
67+ // NB: This cast must succeed as long as the only implementors of
68+ // TransformParamTypeInterface are AnyParamType and ParamType.
69+ Type typeWrappedByParam = cast<ParamType>(paramType).getType ();
70+
71+ if (isa<mlir::smt::IntType>(smtType)) {
72+ if (!isa<IntegerType>(typeWrappedByParam))
73+ return atOp->emitOpError ()
74+ << " the type of " << smtDesc << " #" << idx
75+ << " is !smt.int though the corresponding " << paramDesc
76+ << " type (" << paramType << " ) is not wrapping an integer type" ;
77+ } else if (isa<mlir::smt::BoolType>(smtType)) {
78+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
79+ if (!wrappedIntType || wrappedIntType.getWidth () != 1 )
80+ return atOp->emitOpError ()
81+ << " the type of " << smtDesc << " #" << idx
82+ << " is !smt.bool though the corresponding " << paramDesc
83+ << " type (" << paramType << " ) is not wrapping i1" ;
84+ } else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
85+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
86+ if (!wrappedIntType || wrappedIntType.getWidth () != bvSmtType.getWidth ())
87+ return atOp->emitOpError ()
88+ << " the type of " << smtDesc << " #" << idx << " is " << smtType
89+ << " though the corresponding " << paramDesc << " type ("
90+ << paramType
91+ << " ) is not wrapping an integer type of the same bitwidth" ;
92+ }
93+
94+ return {};
95+ };
96+
5397 if (getOperands ().size () != getBody ().getNumArguments ())
5498 return emitOpError (
5599 " must have the same number of block arguments as operands" );
56100
57- for (auto [i, operandType, blockArgType] :
58- llvm::zip_equal (llvm::seq<unsigned >(0 , getBody ().getNumArguments ()),
59- getOperandTypes (), getBody ().getArgumentTypes ())) {
60- if (isa<transform::AnyParamType>(operandType))
61- continue ; // No type checking as operand is of !transform.any_param type.
62- auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
63- if (!paramOperandType)
64- return emitOpError () << " operand type #" << i
65- << " is not a !transform.param" ;
66- Type wrappedOperandType = paramOperandType.getType ();
67-
68- if (isa<mlir::smt::IntType>(blockArgType)) {
69- if (!isa<IntegerType>(paramOperandType.getType ()))
70- return emitOpError ()
71- << " the type of block arg #" << i
72- << " is !smt.int though the corresponding operand type ("
73- << operandType << " ) is not wrapping an integer type" ;
74- } else if (isa<mlir::smt::BoolType>(blockArgType)) {
75- auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
76- if (!intOperandType || intOperandType.getWidth () != 1 )
77- return emitOpError ()
78- << " the type of block arg #" << i
79- << " is !smt.bool though the corresponding operand type ("
80- << operandType << " ) is not wrapping i1 (i.e. bool)" ;
81- } else if (auto bvBlockArgType =
82- dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
83- auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
84- if (!intOperandType ||
85- intOperandType.getWidth () != bvBlockArgType.getWidth ())
86- return emitOpError ()
87- << " the type of block arg #" << i << " is " << blockArgType
88- << " though the corresponding operand type (" << operandType
89- << " ) is not wrapping an integer type of the same bitwidth" ;
90- }
101+ for (auto [idx, operandType, blockArgType] :
102+ llvm::enumerate (getOperandTypes (), getBody ().getArgumentTypes ())) {
103+ InFlightDiagnostic typeCheckResult =
104+ checkTypes (idx, blockArgType, " block arg" , operandType, " operand" ,
105+ /* atOp=*/ this );
106+ if (LogicalResult (typeCheckResult).failed ())
107+ return typeCheckResult;
91108 }
92109
93110 for (auto &op : getBody ().getOps ()) {
@@ -96,52 +113,19 @@ LogicalResult transform::smt::ConstrainParamsOp::verify() {
96113 " ops contained in region should belong to SMT-dialect" );
97114 }
98115
99- if (getOperands ().size () != getBody ().getNumArguments ())
100- return emitOpError (
101- " must have the same number of block arguments as operands" );
102-
103116 if (yieldTerminator->getNumOperands () != getNumResults ())
104117 return yieldTerminator.emitOpError ()
105118 << " expected terminator to have as many operands as the parent op "
106119 " has results" ;
107120
108- for (auto [i, termOperandType, resultType] : llvm::zip_equal (
109- llvm::seq<unsigned >(0 , yieldTerminator->getNumOperands ()),
121+ for (auto [idx, termOperandType, resultType] : llvm::enumerate (
110122 yieldTerminator->getOperands ().getType (), getResultTypes ())) {
111- if (isa<transform::AnyParamType>(resultType))
112- continue ; // No type checking as result is of !transform.any_param type.
113- auto paramResultType = dyn_cast<transform::ParamType>(resultType);
114- if (!paramResultType)
115- return emitOpError () << " result type #" << i
116- << " is not a !transform.param" ;
117- Type wrappedResultType = paramResultType.getType ();
118-
119- if (isa<mlir::smt::IntType>(termOperandType)) {
120- if (!isa<IntegerType>(wrappedResultType))
121- return yieldTerminator.emitOpError ()
122- << " the type of terminator operand #" << i
123- << " is !smt.int though the corresponding result type ("
124- << resultType
125- << " ) of the parent op is not wrapping an integer type" ;
126- } else if (isa<mlir::smt::BoolType>(termOperandType)) {
127- auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
128- if (!intResultType || intResultType.getWidth () != 1 )
129- return yieldTerminator.emitOpError ()
130- << " the type of terminator operand #" << i
131- << " is !smt.bool though the corresponding result type ("
132- << resultType
133- << " ) of the parent op is not wrapping i1 (i.e. bool)" ;
134- } else if (auto bvOperandType =
135- dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
136- auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
137- if (!intResultType ||
138- intResultType.getWidth () != bvOperandType.getWidth ())
139- return yieldTerminator.emitOpError ()
140- << " the type of terminator operand #" << i << " is "
141- << termOperandType << " though the corresponding result type ("
142- << resultType
143- << " ) is not wrapping an integer type of the same bitwidth" ;
144- }
123+ InFlightDiagnostic typeCheckResult =
124+ checkTypes (idx, termOperandType, " terminator operand" ,
125+ cast<transform::ParamType>(resultType), " result" ,
126+ /* atOp=*/ &yieldTerminator);
127+ if (LogicalResult (typeCheckResult).failed ())
128+ return typeCheckResult;
145129 }
146130
147131 return success ();
0 commit comments