Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 72 additions & 36 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ class TypeConverter {
};

/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms(where `T` is a class derived from `Type`:
/// to any of the following forms (where `T` is a class derived from `Type`):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
Expand All @@ -151,15 +152,7 @@ class TypeConverter {
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
/// ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument is contains is the
/// "call stack" of the recursive conversion: it contains the list of
/// types currently being converted, with the current type being the
/// last one. If it is present more than once in the list, the
/// conversion concerns a recursive type.
///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
Expand All @@ -178,6 +171,9 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
/// parameter, which is the original type of the SSA value.

/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
Expand All @@ -203,11 +199,22 @@ class TypeConverter {

/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
///
/// Note: For target materializations, users can optionally take the original
/// type. This type may be different from the type of the input. For example,
/// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
/// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
/// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
/// that "P2" determines that the legalized type of "t1" is "t3", which may
/// be different from "t2". In this example, the target materialization
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
}

/// Register a conversion function for attributes within types. Type
Expand Down Expand Up @@ -303,21 +310,12 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
Type resultType,
ValueRange inputs) const {
return materializeConversion(argumentMaterializations, builder, loc,
resultType, inputs);
}
Type resultType, ValueRange inputs) const;
Value materializeSourceConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) const {
return materializeConversion(sourceMaterializations, builder, loc,
resultType, inputs);
}
Type resultType, ValueRange inputs) const;
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) const {
return materializeConversion(targetMaterializations, builder, loc,
resultType, inputs);
}
Type resultType, ValueRange inputs,
Type originalType = {}) const;

/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
Expand All @@ -333,21 +331,23 @@ class TypeConverter {
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &)>;

/// The signature of the callback used to materialize a conversion.
/// The signature of the callback used to materialize a source/argument
/// conversion.
///
/// Arguments: builder, result type, inputs, location
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;

/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location, Type)>;

/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
std::function<AttributeConversionResult(Type, Attribute)>;

/// Attempt to materialize a conversion using one of the provided
/// materialization functions.
Value
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType,
ValueRange inputs) const;

/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `std::optional<Type>(T)`
Expand Down Expand Up @@ -388,9 +388,10 @@ class TypeConverter {
cachedMultiConversions.clear();
}

/// Generate a wrapper for the given materialization callback. The callback
/// may take any subclass of `Type` and the wrapper will check for the target
/// type to be of the expected class before calling the callback.
/// Generate a wrapper for the given argument/source materialization
/// callback. The callback may take any subclass of `Type` and the
/// wrapper will check for the target type to be of the expected class
/// before calling the callback.
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
Expand All @@ -402,6 +403,41 @@ class TypeConverter {
};
}

/// Generate a wrapper for the given target materialization callback.
/// The callback may take any subclass of `Type` and the wrapper will check
/// for the target type to be of the expected class before calling the
/// callback.
///
/// With callback of form:
/// `Value(OpBuilder &, T, ValueRange, Location, Type)`
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc, Type originalType) -> std::optional<Value> {
if (T derivedType = dyn_cast<T>(resultType))
return callback(builder, derivedType, inputs, loc, originalType);
return std::nullopt;
};
}
/// With callback of form:
/// `Value(OpBuilder &, T, ValueRange, Location)`
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
Type originalType) -> std::optional<Value> {
return callback(builder, resultType, inputs, loc);
});
}

/// Generate a wrapper for the given memory space conversion callback. The
/// callback may take any subclass of `Attribute` and the wrapper will check
/// for the target attribute to be of the expected class before calling the
Expand Down Expand Up @@ -434,7 +470,7 @@ class TypeConverter {
/// The list of registered materialization functions.
SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;

/// The list of registered type attribute conversion functions.
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
Expand Down
83 changes: 64 additions & 19 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,10 +683,10 @@ enum MaterializationKind {
/// conversion.
class UnresolvedMaterializationRewrite : public OperationRewrite {
public:
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target);
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op,
const TypeConverter *converter,
MaterializationKind kind, Type originalType);

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
Expand All @@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}

/// Return the original type of the SSA value.
Type getOriginalType() const { return originalType; }

private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;

/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;
};
} // namespace

Expand Down Expand Up @@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value buildUnresolvedMaterialization(MaterializationKind kind,
OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
Type originalType,
const TypeConverter *converter);

//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {

UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind)
const TypeConverter *converter, MaterializationKind kind, Type originalType)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind) {
converterAndKind(converter, kind), originalType(originalType) {
assert(!originalType ||
kind == MaterializationKind::Target &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}

Expand Down Expand Up @@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
currentTypeConverter);
/*originalType=*/origType, currentTypeConverter);
mapping.map(newOperand, castValue);
newOperand = castValue;
}
Expand Down Expand Up @@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
/*outputType=*/origArgType, converter);
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
Expand All @@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/replArgs, origArgType, converter);
/*inputs=*/replArgs, /*outputType=*/origArgType,
/*originalType=*/Type(), converter);
mapping.map(origArg, argMat);

Type legalOutputType;
Expand All @@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat),
origArg.getLoc(), argMat, legalOutputType, converter);
origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
/*originalType=*/origArgType, converter);
mapping.map(argMat, targetMat);
}
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
Expand All @@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType, const TypeConverter *converter) {
ValueRange inputs, Type outputType, Type originalType,
const TypeConverter *converter) {
assert(!originalType ||
kind == MaterializationKind::Target &&
"original type is valid only for target materializations");

// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
return inputs.front();
Expand All @@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
originalType);
return convertOp.getResult(0);
}

Expand Down Expand Up @@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
newValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), currentTypeConverter);
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
}

// Remap, and check for any result type changes.
Expand Down Expand Up @@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
[[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
rewriter, op->getLoc(), outputType, inputOperands,
rewrite->getOriginalType());
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
Expand Down Expand Up @@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
MaterializationKind::Source, computeInsertPoint(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
converter);
/*originalType=*/Type(), converter);
rewriterImpl.mapping.map(originalValue, castValue);
inverseMapping[castValue].push_back(originalValue);
llvm::erase(inverseMapping[newValue], originalValue);
Expand Down Expand Up @@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return success();
}

Value TypeConverter::materializeConversion(
ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
Location loc, Type resultType, ValueRange inputs) const {
for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
Location loc,
Type resultType,
ValueRange inputs) const {
for (const MaterializationCallbackFn &fn :
llvm::reverse(argumentMaterializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
return nullptr;
}

Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs) const {
for (const MaterializationCallbackFn &fn :
llvm::reverse(sourceMaterializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
return nullptr;
}

Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
llvm::reverse(targetMaterializations))
if (std::optional<Value> result =
fn(builder, resultType, inputs, loc, originalType))
return *result;
return nullptr;
}

std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
Expand Down
Loading