1515#include " mlir/Dialect/Transform/IR/TransformDialect.h"
1616#include " mlir/Dialect/Transform/IR/TransformInterfaces.h"
1717#include " mlir/Dialect/Transform/IR/TransformOps.h"
18+ #include " mlir/Transforms/DialectConversion.h"
1819
1920using namespace mlir ;
2021
@@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
3637 return success ();
3738}
3839
40+ // ===----------------------------------------------------------------------===//
41+ // CastAndCallOp
42+ // ===----------------------------------------------------------------------===//
43+
44+ DiagnosedSilenceableFailure
45+ transform::CastAndCallOp::apply (transform::TransformRewriter &rewriter,
46+ transform::TransformResults &results,
47+ transform::TransformState &state) {
48+ SmallVector<Value> inputs;
49+ if (getInputs ())
50+ llvm::append_range (inputs, state.getPayloadValues (getInputs ()));
51+
52+ SetVector<Value> outputs;
53+ if (getOutputs ()) {
54+ for (auto output : state.getPayloadValues (getOutputs ()))
55+ outputs.insert (output);
56+
57+ // Verify that the set of output values to be replaced is unique.
58+ if (outputs.size () !=
59+ llvm::range_size (state.getPayloadValues (getOutputs ()))) {
60+ return emitSilenceableFailure (getLoc ())
61+ << " cast and call output values must be unique" ;
62+ }
63+ }
64+
65+ // Get the insertion point for the call.
66+ auto insertionOps = state.getPayloadOps (getInsertionPoint ());
67+ if (!llvm::hasSingleElement (insertionOps)) {
68+ return emitSilenceableFailure (getLoc ())
69+ << " Only one op can be specified as an insertion point" ;
70+ }
71+ bool insertAfter = getInsertAfter ();
72+ Operation *insertionPoint = *insertionOps.begin ();
73+
74+ // Check that all inputs dominate the insertion point, and the insertion
75+ // point dominates all users of the outputs.
76+ DominanceInfo dom (insertionPoint);
77+ for (Value output : outputs) {
78+ for (Operation *user : output.getUsers ()) {
79+ // If we are inserting after the insertion point operation, the
80+ // insertion point operation must properly dominate the user. Otherwise
81+ // basic dominance is enough.
82+ bool doesDominate = insertAfter
83+ ? dom.properlyDominates (insertionPoint, user)
84+ : dom.dominates (insertionPoint, user);
85+ if (!doesDominate) {
86+ return emitDefiniteFailure ()
87+ << " User " << user << " is not dominated by insertion point "
88+ << insertionPoint;
89+ }
90+ }
91+ }
92+
93+ for (Value input : inputs) {
94+ // If we are inserting before the insertion point operation, the
95+ // input must properly dominate the insertion point operation. Otherwise
96+ // basic dominance is enough.
97+ bool doesDominate = insertAfter
98+ ? dom.dominates (input, insertionPoint)
99+ : dom.properlyDominates (input, insertionPoint);
100+ if (!doesDominate) {
101+ return emitDefiniteFailure ()
102+ << " input " << input << " does not dominate insertion point "
103+ << insertionPoint;
104+ }
105+ }
106+
107+ // Get the function to call. This can either be specified by symbol or as a
108+ // transform handle.
109+ func::FuncOp targetFunction = nullptr ;
110+ if (getFunctionName ()) {
111+ targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112+ insertionPoint, *getFunctionName ());
113+ if (!targetFunction) {
114+ return emitDefiniteFailure ()
115+ << " unresolved symbol " << *getFunctionName ();
116+ }
117+ } else if (getFunction ()) {
118+ auto payloadOps = state.getPayloadOps (getFunction ());
119+ if (!llvm::hasSingleElement (payloadOps)) {
120+ return emitDefiniteFailure () << " requires a single function to call" ;
121+ }
122+ targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin ());
123+ if (!targetFunction) {
124+ return emitDefiniteFailure () << " invalid non-function callee" ;
125+ }
126+ } else {
127+ llvm_unreachable (" Invalid CastAndCall op without a function to call" );
128+ return emitDefiniteFailure ();
129+ }
130+
131+ // Verify that the function argument and result lengths match the inputs and
132+ // outputs given to this op.
133+ if (targetFunction.getNumArguments () != inputs.size ()) {
134+ return emitSilenceableFailure (targetFunction.getLoc ())
135+ << " mismatch between number of function arguments "
136+ << targetFunction.getNumArguments () << " and number of inputs "
137+ << inputs.size ();
138+ }
139+ if (targetFunction.getNumResults () != outputs.size ()) {
140+ return emitSilenceableFailure (targetFunction.getLoc ())
141+ << " mismatch between number of function results "
142+ << targetFunction->getNumResults () << " and number of outputs "
143+ << outputs.size ();
144+ }
145+
146+ // Gather all specified converters.
147+ mlir::TypeConverter converter;
148+ if (!getRegion ().empty ()) {
149+ for (Operation &op : getRegion ().front ()) {
150+ cast<transform::TypeConverterBuilderOpInterface>(&op)
151+ .populateTypeMaterializations (converter);
152+ }
153+ }
154+
155+ if (insertAfter)
156+ rewriter.setInsertionPointAfter (insertionPoint);
157+ else
158+ rewriter.setInsertionPoint (insertionPoint);
159+
160+ for (auto [input, type] :
161+ llvm::zip_equal (inputs, targetFunction.getArgumentTypes ())) {
162+ if (input.getType () != type) {
163+ Value newInput = converter.materializeSourceConversion (
164+ rewriter, input.getLoc (), type, input);
165+ if (!newInput) {
166+ return emitDefiniteFailure () << " Failed to materialize conversion of "
167+ << input << " to type " << type;
168+ }
169+ input = newInput;
170+ }
171+ }
172+
173+ auto callOp = rewriter.create <func::CallOp>(insertionPoint->getLoc (),
174+ targetFunction, inputs);
175+
176+ // Cast the call results back to the expected types. If any conversions fail
177+ // this is a definite failure as the call has been constructed at this point.
178+ for (auto [output, newOutput] :
179+ llvm::zip_equal (outputs, callOp.getResults ())) {
180+ Value convertedOutput = newOutput;
181+ if (output.getType () != newOutput.getType ()) {
182+ convertedOutput = converter.materializeTargetConversion (
183+ rewriter, output.getLoc (), output.getType (), newOutput);
184+ if (!convertedOutput) {
185+ return emitDefiniteFailure ()
186+ << " Failed to materialize conversion of " << newOutput
187+ << " to type " << output.getType ();
188+ }
189+ }
190+ rewriter.replaceAllUsesExcept (output, convertedOutput, callOp);
191+ }
192+ results.set (cast<OpResult>(getResult ()), {callOp});
193+ return DiagnosedSilenceableFailure::success ();
194+ }
195+
196+ LogicalResult transform::CastAndCallOp::verify () {
197+ if (!getRegion ().empty ()) {
198+ for (Operation &op : getRegion ().front ()) {
199+ if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200+ InFlightDiagnostic diag = emitOpError ()
201+ << " expected children ops to implement "
202+ " TypeConverterBuilderOpInterface" ;
203+ diag.attachNote (op.getLoc ()) << " op without interface" ;
204+ return diag;
205+ }
206+ }
207+ }
208+ if (!getFunction () && !getFunctionName ()) {
209+ return emitOpError () << " expected a function handle or name to call" ;
210+ }
211+ if (getFunction () && getFunctionName ()) {
212+ return emitOpError () << " function handle and name are mutually exclusive" ;
213+ }
214+ return success ();
215+ }
216+
217+ void transform::CastAndCallOp::getEffects (
218+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219+ transform::onlyReadsHandle (getInsertionPoint (), effects);
220+ if (getInputs ())
221+ transform::onlyReadsHandle (getInputs (), effects);
222+ if (getOutputs ())
223+ transform::onlyReadsHandle (getOutputs (), effects);
224+ if (getFunction ())
225+ transform::onlyReadsHandle (getFunction (), effects);
226+ transform::producesHandle (getResult (), effects);
227+ transform::modifiesPayload (effects);
228+ }
229+
39230// ===----------------------------------------------------------------------===//
40231// Transform op registration
41232// ===----------------------------------------------------------------------===//
0 commit comments