Skip to content

Commit b060217

Browse files
committed
Address Alex's comments
1 parent 9e8a2e2 commit b060217

File tree

2 files changed

+131
-86
lines changed

2 files changed

+131
-86
lines changed

mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp

Lines changed: 59 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -44,50 +44,67 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
4444

4545
LogicalResult transform::smt::ConstrainParamsOp::verify() {
4646
auto yieldTerminator =
47-
llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
47+
dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
4848
if (!yieldTerminator)
4949
return emitOpError() << "expected '"
5050
<< mlir::smt::YieldOp::getOperationName()
5151
<< "' as terminator";
5252

53+
auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
54+
Type paramType, StringRef paramDesc,
55+
auto *atOp) -> InFlightDiagnostic {
56+
if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
57+
smtType))
58+
return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
59+
<< " is expected to be either a !smt.bool, a "
60+
"!smt.int, or a !smt.bv";
61+
62+
assert(isa<TransformParamTypeInterface>(paramType) &&
63+
"ODS specifies params' type should implement param interface");
64+
if (isa<transform::AnyParamType>(paramType))
65+
return {}; // No further checks can be done.
66+
67+
// NB: This cast must succeed as long as the only implementors of
68+
// TransformParamTypeInterface are AnyParamType and ParamType.
69+
Type typeWrappedByParam = cast<ParamType>(paramType).getType();
70+
71+
if (isa<mlir::smt::IntType>(smtType)) {
72+
if (!isa<IntegerType>(typeWrappedByParam))
73+
return atOp->emitOpError()
74+
<< "the type of " << smtDesc << " #" << idx
75+
<< " is !smt.int though the corresponding " << paramDesc
76+
<< " type (" << paramType << ") is not wrapping an integer type";
77+
} else if (isa<mlir::smt::BoolType>(smtType)) {
78+
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
79+
if (!wrappedIntType || wrappedIntType.getWidth() != 1)
80+
return atOp->emitOpError()
81+
<< "the type of " << smtDesc << " #" << idx
82+
<< " is !smt.bool though the corresponding " << paramDesc
83+
<< " type (" << paramType << ") is not wrapping i1";
84+
} else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
85+
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
86+
if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
87+
return atOp->emitOpError()
88+
<< "the type of " << smtDesc << " #" << idx << " is " << smtType
89+
<< " though the corresponding " << paramDesc << " type ("
90+
<< paramType
91+
<< ") is not wrapping an integer type of the same bitwidth";
92+
}
93+
94+
return {};
95+
};
96+
5397
if (getOperands().size() != getBody().getNumArguments())
5498
return emitOpError(
5599
"must have the same number of block arguments as operands");
56100

57-
for (auto [i, operandType, blockArgType] :
58-
llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
59-
getOperandTypes(), getBody().getArgumentTypes())) {
60-
if (isa<transform::AnyParamType>(operandType))
61-
continue; // No type checking as operand is of !transform.any_param type.
62-
auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
63-
if (!paramOperandType)
64-
return emitOpError() << "operand type #" << i
65-
<< " is not a !transform.param";
66-
Type wrappedOperandType = paramOperandType.getType();
67-
68-
if (isa<mlir::smt::IntType>(blockArgType)) {
69-
if (!isa<IntegerType>(paramOperandType.getType()))
70-
return emitOpError()
71-
<< "the type of block arg #" << i
72-
<< " is !smt.int though the corresponding operand type ("
73-
<< operandType << ") is not wrapping an integer type";
74-
} else if (isa<mlir::smt::BoolType>(blockArgType)) {
75-
auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
76-
if (!intOperandType || intOperandType.getWidth() != 1)
77-
return emitOpError()
78-
<< "the type of block arg #" << i
79-
<< " is !smt.bool though the corresponding operand type ("
80-
<< operandType << ") is not wrapping i1 (i.e. bool)";
81-
} else if (auto bvBlockArgType =
82-
dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
83-
auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
84-
if (!intOperandType ||
85-
intOperandType.getWidth() != bvBlockArgType.getWidth())
86-
return emitOpError()
87-
<< "the type of block arg #" << i << " is " << blockArgType
88-
<< " though the corresponding operand type (" << operandType
89-
<< ") is not wrapping an integer type of the same bitwidth";
90-
}
101+
for (auto [idx, operandType, blockArgType] :
102+
llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
103+
InFlightDiagnostic typeCheckResult =
104+
checkTypes(idx, blockArgType, "block arg", operandType, "operand",
105+
/*atOp=*/this);
106+
if (LogicalResult(typeCheckResult).failed())
107+
return typeCheckResult;
91108
}
92109

93110
for (auto &op : getBody().getOps()) {
@@ -96,52 +113,19 @@ LogicalResult transform::smt::ConstrainParamsOp::verify() {
96113
"ops contained in region should belong to SMT-dialect");
97114
}
98115

99-
if (getOperands().size() != getBody().getNumArguments())
100-
return emitOpError(
101-
"must have the same number of block arguments as operands");
102-
103116
if (yieldTerminator->getNumOperands() != getNumResults())
104117
return yieldTerminator.emitOpError()
105118
<< "expected terminator to have as many operands as the parent op "
106119
"has results";
107120

108-
for (auto [i, termOperandType, resultType] : llvm::zip_equal(
109-
llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
121+
for (auto [idx, termOperandType, resultType] : llvm::enumerate(
110122
yieldTerminator->getOperands().getType(), getResultTypes())) {
111-
if (isa<transform::AnyParamType>(resultType))
112-
continue; // No type checking as result is of !transform.any_param type.
113-
auto paramResultType = dyn_cast<transform::ParamType>(resultType);
114-
if (!paramResultType)
115-
return emitOpError() << "result type #" << i
116-
<< " is not a !transform.param";
117-
Type wrappedResultType = paramResultType.getType();
118-
119-
if (isa<mlir::smt::IntType>(termOperandType)) {
120-
if (!isa<IntegerType>(wrappedResultType))
121-
return yieldTerminator.emitOpError()
122-
<< "the type of terminator operand #" << i
123-
<< " is !smt.int though the corresponding result type ("
124-
<< resultType
125-
<< ") of the parent op is not wrapping an integer type";
126-
} else if (isa<mlir::smt::BoolType>(termOperandType)) {
127-
auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
128-
if (!intResultType || intResultType.getWidth() != 1)
129-
return yieldTerminator.emitOpError()
130-
<< "the type of terminator operand #" << i
131-
<< " is !smt.bool though the corresponding result type ("
132-
<< resultType
133-
<< ") of the parent op is not wrapping i1 (i.e. bool)";
134-
} else if (auto bvOperandType =
135-
dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
136-
auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
137-
if (!intResultType ||
138-
intResultType.getWidth() != bvOperandType.getWidth())
139-
return yieldTerminator.emitOpError()
140-
<< "the type of terminator operand #" << i << " is "
141-
<< termOperandType << " though the corresponding result type ("
142-
<< resultType
143-
<< ") is not wrapping an integer type of the same bitwidth";
144-
}
123+
InFlightDiagnostic typeCheckResult =
124+
checkTypes(idx, termOperandType, "terminator operand",
125+
cast<transform::ParamType>(resultType), "result",
126+
/*atOp=*/&yieldTerminator);
127+
if (LogicalResult(typeCheckResult).failed())
128+
return typeCheckResult;
145129
}
146130

147131
return success();

mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
22

3-
// CHECK-LABEL: @constraint_not_using_smt_ops
3+
// CHECK-LABEL: @incorrect terminator
44
module attributes {transform.with_named_sequence} {
5-
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
5+
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
66
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
7-
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
7+
// expected-error@below {{op expected 'smt.yield' as terminator}}
88
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
99
^bb0(%param_as_smt_var: !smt.int):
10-
%c4 = arith.constant 4 : i32
11-
// This is the kind of thing one might think works:
12-
//arith.remsi %param_as_smt_var, %c4 : i32
10+
transform.yield
1311
}
1412
transform.yield
1513
}
@@ -31,6 +29,23 @@ module attributes {transform.with_named_sequence} {
3129

3230
// -----
3331

32+
// CHECK-LABEL: @constraint_not_using_smt_ops
33+
module attributes {transform.with_named_sequence} {
34+
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
35+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
36+
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
37+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
38+
^bb0(%param_as_smt_var: !smt.int):
39+
%c4 = arith.constant 4 : i32
40+
// This is the kind of thing one might think works:
41+
//arith.remsi %param_as_smt_var, %c4 : i32
42+
}
43+
transform.yield
44+
}
45+
}
46+
47+
// -----
48+
3449
// CHECK-LABEL: @results_not_one_to_one_with_vars
3550
module attributes {transform.with_named_sequence} {
3651
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
@@ -46,11 +61,27 @@ module attributes {transform.with_named_sequence} {
4661

4762
// -----
4863

49-
// CHECK-LABEL: @mismatched_type_bool
64+
// CHECK-LABEL: @non_smt_type_block_args
5065
module attributes {transform.with_named_sequence} {
51-
transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
66+
transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) {
67+
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
68+
// expected-error@below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}}
69+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i8>) {
70+
^bb0(%param_as_smt_var: !transform.param<i8>):
71+
smt.yield %param_as_smt_var : !transform.param<i8>
72+
}
73+
transform.yield
74+
}
75+
}
76+
77+
78+
// -----
79+
80+
// CHECK-LABEL: @mismatched_arg_type_bool
81+
module attributes {transform.with_named_sequence} {
82+
transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) {
5283
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
53-
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
84+
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1}}
5485
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
5586
^bb0(%param_as_smt_var: !smt.bool):
5687
smt.yield %param_as_smt_var : !smt.bool
@@ -61,9 +92,9 @@ module attributes {transform.with_named_sequence} {
6192

6293
// -----
6394

64-
// CHECK-LABEL: @mismatched_type_bitvector
95+
// CHECK-LABEL: @mismatched_arg_type_bitvector
6596
module attributes {transform.with_named_sequence} {
66-
transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
97+
transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
6798
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
6899
// expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
69100
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
@@ -73,3 +104,33 @@ module attributes {transform.with_named_sequence} {
73104
transform.yield
74105
}
75106
}
107+
108+
// -----
109+
110+
// CHECK-LABEL: @mismatched_result_type_bool
111+
module attributes {transform.with_named_sequence} {
112+
transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) {
113+
%param_as_param = transform.param.constant 1 -> !transform.param<i1>
114+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> (!transform.param<i64>) {
115+
^bb0(%param_as_smt_var: !smt.bool):
116+
// expected-error@below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param<i64>') is not wrapping i1}}
117+
smt.yield %param_as_smt_var : !smt.bool
118+
}
119+
transform.yield
120+
}
121+
}
122+
123+
// -----
124+
125+
// CHECK-LABEL: @mismatched_result_type_bitvector
126+
module attributes {transform.with_named_sequence} {
127+
transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
128+
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
129+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i64>) {
130+
^bb0(%param_as_smt_var: !smt.bv<8>):
131+
// expected-error@below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
132+
smt.yield %param_as_smt_var : !smt.bv<8>
133+
}
134+
transform.yield
135+
}
136+
}

0 commit comments

Comments
 (0)