Skip to content

Commit 80bc49f

Browse files
[mlir][Transforms] Dialect conversion: add originalType param to materializations v2
1 parent 9f24c14 commit 80bc49f

File tree

2 files changed

+136
-55
lines changed

2 files changed

+136
-55
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class TypeConverter {
138138
};
139139

140140
/// Register a conversion function. A conversion function must be convertible
141-
/// to any of the following forms(where `T` is a class derived from `Type`:
141+
/// to any of the following forms (where `T` is a class derived from `Type`):
142+
///
142143
/// * std::optional<Type>(T)
143144
/// - This form represents a 1-1 type conversion. It should return nullptr
144145
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
@@ -151,15 +152,7 @@ class TypeConverter {
151152
/// existing value are expected to be removed during conversion. If
152153
/// `std::nullopt` is returned, the converter is allowed to try another
153154
/// conversion function to perform the conversion.
154-
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
155-
/// ArrayRef<Type>)
156-
/// - This form represents a 1-N type conversion supporting recursive
157-
/// types. The first two arguments and the return value are the same as
158-
/// for the regular 1-N form. The third argument is contains is the
159-
/// "call stack" of the recursive conversion: it contains the list of
160-
/// types currently being converted, with the current type being the
161-
/// last one. If it is present more than once in the list, the
162-
/// conversion concerns a recursive type.
155+
///
163156
/// Note: When attempting to convert a type, e.g. via 'convertType', the
164157
/// mostly recently added conversions will be invoked first.
165158
template <typename FnT, typename T = typename llvm::function_traits<
@@ -178,6 +171,9 @@ class TypeConverter {
178171
/// it failed but other materialization can be attempted, and `nullptr` on
179172
/// unrecoverable failure. Materialization functions must be provided when a
180173
/// type conversion may persist after the conversion has finished.
174+
///
175+
/// Note: Target materializations may optionally accept an additional Type
176+
/// parameter, which is the original type of the SSA value.
181177

182178
/// This method registers a materialization that will be called when
183179
/// converting (potentially multiple) block arguments that were the result of
@@ -203,11 +199,22 @@ class TypeConverter {
203199

204200
/// This method registers a materialization that will be called when
205201
/// converting an illegal (source) value to a legal (target) type.
202+
///
203+
/// Note: For target materializations, users can optionally take the original
204+
/// type. This type may be different from the type of the input. For example,
205+
/// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
206+
/// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
207+
/// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
208+
/// that "P2" determines that the legalized type of "t1" is "t3", which may
209+
/// be different from "t2". In this example, the target materialization
210+
/// will be invoked with: outputType = "t3", inputs = "v2",
211+
// originalType = "t1". Note that the original type "t1" cannot be recovered
212+
/// from just "t3" and "v2"; that's why the originalType parameter exists.
206213
template <typename FnT, typename T = typename llvm::function_traits<
207214
std::decay_t<FnT>>::template arg_t<1>>
208215
void addTargetMaterialization(FnT &&callback) {
209216
targetMaterializations.emplace_back(
210-
wrapMaterialization<T>(std::forward<FnT>(callback)));
217+
wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
211218
}
212219

213220
/// Register a conversion function for attributes within types. Type
@@ -303,21 +310,12 @@ class TypeConverter {
303310
/// `add*Materialization` for more information on the context for these
304311
/// methods.
305312
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
306-
Type resultType,
307-
ValueRange inputs) const {
308-
return materializeConversion(argumentMaterializations, builder, loc,
309-
resultType, inputs);
310-
}
313+
Type resultType, ValueRange inputs) const;
311314
Value materializeSourceConversion(OpBuilder &builder, Location loc,
312-
Type resultType, ValueRange inputs) const {
313-
return materializeConversion(sourceMaterializations, builder, loc,
314-
resultType, inputs);
315-
}
315+
Type resultType, ValueRange inputs) const;
316316
Value materializeTargetConversion(OpBuilder &builder, Location loc,
317-
Type resultType, ValueRange inputs) const {
318-
return materializeConversion(targetMaterializations, builder, loc,
319-
resultType, inputs);
320-
}
317+
Type resultType, ValueRange inputs,
318+
Type originalType = {}) const;
321319

322320
/// Convert an attribute present `attr` from within the type `type` using
323321
/// the registered conversion functions. If no applicable conversion has been
@@ -333,21 +331,23 @@ class TypeConverter {
333331
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
334332
Type, SmallVectorImpl<Type> &)>;
335333

336-
/// The signature of the callback used to materialize a conversion.
334+
/// The signature of the callback used to materialize a source/argument
335+
/// conversion.
336+
///
337+
/// Arguments: builder, result type, inputs, location
337338
using MaterializationCallbackFn = std::function<std::optional<Value>(
338339
OpBuilder &, Type, ValueRange, Location)>;
339340

341+
/// The signature of the callback used to materialize a target conversion.
342+
///
343+
/// Arguments: builder, result type, inputs, location, original type
344+
using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
345+
OpBuilder &, Type, ValueRange, Location, Type)>;
346+
340347
/// The signature of the callback used to convert a type attribute.
341348
using TypeAttributeConversionCallbackFn =
342349
std::function<AttributeConversionResult(Type, Attribute)>;
343350

344-
/// Attempt to materialize a conversion using one of the provided
345-
/// materialization functions.
346-
Value
347-
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
348-
OpBuilder &builder, Location loc, Type resultType,
349-
ValueRange inputs) const;
350-
351351
/// Generate a wrapper for the given callback. This allows for accepting
352352
/// different callback forms, that all compose into a single version.
353353
/// With callback of form: `std::optional<Type>(T)`
@@ -388,9 +388,10 @@ class TypeConverter {
388388
cachedMultiConversions.clear();
389389
}
390390

391-
/// Generate a wrapper for the given materialization callback. The callback
392-
/// may take any subclass of `Type` and the wrapper will check for the target
393-
/// type to be of the expected class before calling the callback.
391+
/// Generate a wrapper for the given argument/source materialization
392+
/// callback. The callback may take any subclass of `Type` and the
393+
/// wrapper will check for the target type to be of the expected class
394+
/// before calling the callback.
394395
template <typename T, typename FnT>
395396
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
396397
return [callback = std::forward<FnT>(callback)](
@@ -402,6 +403,41 @@ class TypeConverter {
402403
};
403404
}
404405

406+
/// Generate a wrapper for the given target materialization callback.
407+
/// The callback may take any subclass of `Type` and the wrapper will check
408+
/// for the target type to be of the expected class before calling the
409+
/// callback.
410+
///
411+
/// With callback of form:
412+
/// `Value(OpBuilder &, T, ValueRange, Location, Type)`
413+
template <typename T, typename FnT>
414+
std::enable_if_t<
415+
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
416+
TargetMaterializationCallbackFn>
417+
wrapTargetMaterialization(FnT &&callback) const {
418+
return [callback = std::forward<FnT>(callback)](
419+
OpBuilder &builder, Type resultType, ValueRange inputs,
420+
Location loc, Type originalType) -> std::optional<Value> {
421+
if (T derivedType = dyn_cast<T>(resultType))
422+
return callback(builder, derivedType, inputs, loc, originalType);
423+
return std::nullopt;
424+
};
425+
}
426+
/// With callback of form:
427+
/// `Value(OpBuilder &, T, ValueRange, Location)`
428+
template <typename T, typename FnT>
429+
std::enable_if_t<
430+
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
431+
TargetMaterializationCallbackFn>
432+
wrapTargetMaterialization(FnT &&callback) const {
433+
return wrapTargetMaterialization<T>(
434+
[callback = std::forward<FnT>(callback)](
435+
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
436+
Type originalType) -> std::optional<Value> {
437+
return callback(builder, resultType, inputs, loc);
438+
});
439+
}
440+
405441
/// Generate a wrapper for the given memory space conversion callback. The
406442
/// callback may take any subclass of `Attribute` and the wrapper will check
407443
/// for the target attribute to be of the expected class before calling the
@@ -434,7 +470,7 @@ class TypeConverter {
434470
/// The list of registered materialization functions.
435471
SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
436472
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
437-
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
473+
SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
438474

439475
/// The list of registered type attribute conversion functions.
440476
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,10 @@ enum MaterializationKind {
683683
/// conversion.
684684
class UnresolvedMaterializationRewrite : public OperationRewrite {
685685
public:
686-
UnresolvedMaterializationRewrite(
687-
ConversionPatternRewriterImpl &rewriterImpl,
688-
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
689-
MaterializationKind kind = MaterializationKind::Target);
686+
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
687+
UnrealizedConversionCastOp op,
688+
const TypeConverter *converter,
689+
MaterializationKind kind, Type originalType);
690690

691691
static bool classof(const IRRewrite *rewrite) {
692692
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
708708
return converterAndKind.getInt();
709709
}
710710

711+
/// Return the original type of the SSA value.
712+
Type getOriginalType() const { return originalType; }
713+
711714
private:
712715
/// The corresponding type converter to use when resolving this
713716
/// materialization, and the kind of this materialization.
714717
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
715718
converterAndKind;
719+
720+
/// The original type of the SSA value. Only used for target
721+
/// materializations.
722+
Type originalType;
716723
};
717724
} // namespace
718725

@@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
808815
Value buildUnresolvedMaterialization(MaterializationKind kind,
809816
OpBuilder::InsertPoint ip, Location loc,
810817
ValueRange inputs, Type outputType,
818+
Type originalType,
811819
const TypeConverter *converter);
812820

813821
//===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {
10341042

10351043
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
10361044
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1037-
const TypeConverter *converter, MaterializationKind kind)
1045+
const TypeConverter *converter, MaterializationKind kind, Type originalType)
10381046
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1039-
converterAndKind(converter, kind) {
1047+
converterAndKind(converter, kind), originalType(originalType) {
1048+
assert(!originalType ||
1049+
kind == MaterializationKind::Target &&
1050+
"original type is valid only for target materializations");
10401051
rewriterImpl.unresolvedMaterializations[op] = this;
10411052
}
10421053

@@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11391150
Value castValue = buildUnresolvedMaterialization(
11401151
MaterializationKind::Target, computeInsertPoint(newOperand),
11411152
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1142-
currentTypeConverter);
1153+
/*originalType=*/origType, currentTypeConverter);
11431154
mapping.map(newOperand, castValue);
11441155
newOperand = castValue;
11451156
}
@@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12551266
MaterializationKind::Source,
12561267
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
12571268
/*inputs=*/ValueRange(),
1258-
/*outputType=*/origArgType, converter);
1269+
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
12591270
mapping.map(origArg, repl);
12601271
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
12611272
continue;
@@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12801291
Value argMat = buildUnresolvedMaterialization(
12811292
MaterializationKind::Argument,
12821293
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1283-
/*inputs=*/replArgs, origArgType, converter);
1294+
/*inputs=*/replArgs, /*outputType=*/origArgType,
1295+
/*originalType=*/Type(), converter);
12841296
mapping.map(origArg, argMat);
12851297

12861298
Type legalOutputType;
@@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12991311
if (legalOutputType && legalOutputType != origArgType) {
13001312
Value targetMat = buildUnresolvedMaterialization(
13011313
MaterializationKind::Target, computeInsertPoint(argMat),
1302-
origArg.getLoc(), argMat, legalOutputType, converter);
1314+
origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
1315+
/*originalType=*/origArgType, converter);
13031316
mapping.map(argMat, targetMat);
13041317
}
13051318
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13221335
/// of input operands.
13231336
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13241337
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1325-
ValueRange inputs, Type outputType, const TypeConverter *converter) {
1338+
ValueRange inputs, Type outputType, Type originalType,
1339+
const TypeConverter *converter) {
1340+
assert(!originalType ||
1341+
kind == MaterializationKind::Target &&
1342+
"original type is valid only for target materializations");
1343+
13261344
// Avoid materializing an unnecessary cast.
13271345
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13281346
return inputs.front();
@@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13331351
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13341352
auto convertOp =
13351353
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1336-
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1354+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1355+
originalType);
13371356
return convertOp.getResult(0);
13381357
}
13391358

@@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13811400
newValue = buildUnresolvedMaterialization(
13821401
MaterializationKind::Source, computeInsertPoint(result),
13831402
result.getLoc(), /*inputs=*/ValueRange(),
1384-
/*outputType=*/result.getType(), currentTypeConverter);
1403+
/*outputType=*/result.getType(), /*originalType=*/Type(),
1404+
currentTypeConverter);
13851405
}
13861406

13871407
// Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
24082428
[[fallthrough]];
24092429
case MaterializationKind::Target:
24102430
newMaterialization = converter->materializeTargetConversion(
2411-
rewriter, op->getLoc(), outputType, inputOperands);
2431+
rewriter, op->getLoc(), outputType, inputOperands,
2432+
rewrite->getOriginalType());
24122433
break;
24132434
case MaterializationKind::Source:
24142435
newMaterialization = converter->materializeSourceConversion(
@@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25652586
MaterializationKind::Source, computeInsertPoint(newValue),
25662587
originalValue.getLoc(),
25672588
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2568-
converter);
2589+
/*originalType=*/Type(), converter);
25692590
rewriterImpl.mapping.map(originalValue, castValue);
25702591
inverseMapping[castValue].push_back(originalValue);
25712592
llvm::erase(inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
27872808
return success();
27882809
}
27892810

2790-
Value TypeConverter::materializeConversion(
2791-
ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
2792-
Location loc, Type resultType, ValueRange inputs) const {
2793-
for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
2811+
Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
2812+
Location loc,
2813+
Type resultType,
2814+
ValueRange inputs) const {
2815+
for (const MaterializationCallbackFn &fn :
2816+
llvm::reverse(argumentMaterializations))
2817+
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2818+
return *result;
2819+
return nullptr;
2820+
}
2821+
2822+
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
2823+
Location loc, Type resultType,
2824+
ValueRange inputs) const {
2825+
for (const MaterializationCallbackFn &fn :
2826+
llvm::reverse(sourceMaterializations))
27942827
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
27952828
return *result;
27962829
return nullptr;
27972830
}
27982831

2832+
Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
2833+
Location loc, Type resultType,
2834+
ValueRange inputs,
2835+
Type originalType) const {
2836+
for (const TargetMaterializationCallbackFn &fn :
2837+
llvm::reverse(targetMaterializations))
2838+
if (std::optional<Value> result =
2839+
fn(builder, resultType, inputs, loc, originalType))
2840+
return *result;
2841+
return nullptr;
2842+
}
2843+
27992844
std::optional<TypeConverter::SignatureConversion>
28002845
TypeConverter::convertBlockSignature(Block *block) const {
28012846
SignatureConversion conversion(block->getNumArguments());

0 commit comments

Comments
 (0)