Skip to content

Commit 2c438b0

Browse files
[mlir][Transforms] Dialect conversion: add originalType param to materialization
This commit adds an `originalType` parameter to all materialization functions. Without this parameter, target materializations are underspecified. Note: `originalType` is only needed for target materializations. For source/argument materializations, `originalType` always matches `outputType`. However, to keep the code base simple (i.e., reuse `MaterializationCallbackFn` for all three materializations), `originalType` is passed to all three materializations, even though it is only really needed for target materializations. `originalType` is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type. For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists. This commit also puts the `Location` parameter right after the `OpBuilder` parameter to be consistent with MLIR conventions. This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed). The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer. Note for LLVM integration: For all argument/source/target materialization functions, move the `Location` parameter to the second position and add a `Type originalType` parameter to the lambda. No changes are needed to the body of the lambda. When an argument/source materialization is called in your code base, pass the output type as original type. When a target materialization is called, try to pass the original type of the SSA value, which may match `inputs.front().getType()`. If the original type cannot be recovered (which is unlikely), pass `Type()`.
1 parent 9f24c14 commit 2c438b0

File tree

31 files changed

+238
-175
lines changed

31 files changed

+238
-175
lines changed

flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
173173
}
174174

175175
static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
176-
BoxProcType type,
176+
mlir::Location loc, BoxProcType type,
177177
mlir::ValueRange inputs,
178-
mlir::Location loc) {
178+
mlir::Type originalType) {
179179
assert(inputs.size() == 1);
180180
return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
181181
inputs[0]);

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -170,21 +170,27 @@ class TypeConverter {
170170

171171
/// All of the following materializations require function objects that are
172172
/// convertible to the following form:
173-
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
173+
/// `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
174174
/// where `T` is any subclass of `Type`. This function is responsible for
175175
/// creating an operation, using the OpBuilder and Location provided, that
176176
/// "casts" a range of values into a single value of the given type `T`. It
177177
/// must return a Value of the type `T` on success, an `std::nullopt` if
178178
/// it failed but other materialization can be attempted, and `nullptr` on
179179
/// unrecoverable failure. Materialization functions must be provided when a
180180
/// type conversion may persist after the conversion has finished.
181+
///
182+
/// The type that is provided as the 5-th argument is the original type of
183+
/// value. For more details, see the documentation below.
181184

182185
/// This method registers a materialization that will be called when
183186
/// converting (potentially multiple) block arguments that were the result of
184187
/// a signature conversion of a single block argument, to a single SSA value
185188
/// with the old block argument type.
189+
///
190+
/// Note: The original type matches the result type `T` for argument
191+
/// materializations.
186192
template <typename FnT, typename T = typename llvm::function_traits<
187-
std::decay_t<FnT>>::template arg_t<1>>
193+
std::decay_t<FnT>>::template arg_t<2>>
188194
void addArgumentMaterialization(FnT &&callback) {
189195
argumentMaterializations.emplace_back(
190196
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,17 +200,31 @@ class TypeConverter {
194200
/// converting a legal replacement value back to an illegal source type.
195201
/// This is used when some uses of the original, illegal value must persist
196202
/// beyond the main conversion.
203+
///
204+
/// Note: The original type matches the result type `T` for source
205+
/// materializations.
197206
template <typename FnT, typename T = typename llvm::function_traits<
198-
std::decay_t<FnT>>::template arg_t<1>>
207+
std::decay_t<FnT>>::template arg_t<2>>
199208
void addSourceMaterialization(FnT &&callback) {
200209
sourceMaterializations.emplace_back(
201210
wrapMaterialization<T>(std::forward<FnT>(callback)));
202211
}
203212

204213
/// This method registers a materialization that will be called when
205214
/// converting an illegal (source) value to a legal (target) type.
215+
///
216+
/// Note: For target materializations, the original type can be
217+
/// different from the type of the input. For example, let's assume that a
218+
/// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
219+
/// (type "t2"). Then a different conversion pattern "P2" matches an op that
220+
/// has "v1" as an operand. Let's furthermore assume that "P2" determines
221+
/// that the legalized type of "t1" is "t3", which may be different from
222+
/// "t2". In this example, the target materialization callback will be
223+
/// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
224+
/// that the original type "t1" cannot be recovered from just "t3" and "v2";
225+
/// that's why the originalType parameter exists.
206226
template <typename FnT, typename T = typename llvm::function_traits<
207-
std::decay_t<FnT>>::template arg_t<1>>
227+
std::decay_t<FnT>>::template arg_t<2>>
208228
void addTargetMaterialization(FnT &&callback) {
209229
targetMaterializations.emplace_back(
210230
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
303323
/// `add*Materialization` for more information on the context for these
304324
/// methods.
305325
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
306-
Type resultType,
307-
ValueRange inputs) const {
326+
Type resultType, ValueRange inputs,
327+
Type originalType) const {
308328
return materializeConversion(argumentMaterializations, builder, loc,
309-
resultType, inputs);
329+
resultType, inputs, originalType);
310330
}
311331
Value materializeSourceConversion(OpBuilder &builder, Location loc,
312-
Type resultType, ValueRange inputs) const {
332+
Type resultType, ValueRange inputs,
333+
Type originalType) const {
313334
return materializeConversion(sourceMaterializations, builder, loc,
314-
resultType, inputs);
335+
resultType, inputs, originalType);
315336
}
316337
Value materializeTargetConversion(OpBuilder &builder, Location loc,
317-
Type resultType, ValueRange inputs) const {
338+
Type resultType, ValueRange inputs,
339+
Type originalType) const {
318340
return materializeConversion(targetMaterializations, builder, loc,
319-
resultType, inputs);
341+
resultType, inputs, originalType);
320342
}
321343

322344
/// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
334356
Type, SmallVectorImpl<Type> &)>;
335357

336358
/// The signature of the callback used to materialize a conversion.
359+
///
360+
/// Arguments: builder, location, result type, inputs, original type
337361
using MaterializationCallbackFn = std::function<std::optional<Value>(
338-
OpBuilder &, Type, ValueRange, Location)>;
362+
OpBuilder &, Location, Type, ValueRange, Type)>;
339363

340364
/// The signature of the callback used to convert a type attribute.
341365
using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
346370
Value
347371
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
348372
OpBuilder &builder, Location loc, Type resultType,
349-
ValueRange inputs) const;
373+
ValueRange inputs, Type originalType) const;
350374

351375
/// Generate a wrapper for the given callback. This allows for accepting
352376
/// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
394418
template <typename T, typename FnT>
395419
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
396420
return [callback = std::forward<FnT>(callback)](
397-
OpBuilder &builder, Type resultType, ValueRange inputs,
398-
Location loc) -> std::optional<Value> {
421+
OpBuilder &builder, Location loc, Type resultType,
422+
ValueRange inputs, Type originalType) -> std::optional<Value> {
399423
if (T derivedType = dyn_cast<T>(resultType))
400-
return callback(builder, derivedType, inputs, loc);
424+
return callback(builder, loc, derivedType, inputs, originalType);
401425
return std::nullopt;
402426
};
403427
}

mlir/include/mlir/Transforms/OneToNTypeConversion.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
4444
/// materializations for 1:N type conversions, which materialize one value in
4545
/// a source type as N values in target types.
4646
using OneToNMaterializationCallbackFn =
47-
std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48-
Value, Location)>;
47+
std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
48+
TypeRange, Value, Type)>;
4949

5050
/// Creates the mapping of the given range of original types to target types
5151
/// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
6363
/// returns `std::nullopt`.
6464
std::optional<SmallVector<Value>>
6565
materializeTargetConversion(OpBuilder &builder, Location loc,
66-
TypeRange resultTypes, Value input) const;
66+
TypeRange resultTypes, Value input,
67+
Type originalType) const;
6768

6869
/// Adds a 1:N target materialization to the converter. Such materializations
6970
/// build IR that converts N values with target types into 1 value of the

mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
281281

282282
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
283283
// in patterns for other dialects.
284-
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285-
ValueRange inputs, Location loc) {
284+
auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
285+
ValueRange inputs, Type originalType) {
286286
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287287
return std::optional<Value>(cast.getResult(0));
288288
};

mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
4646
Location loc = arg.getLoc();
4747
Value newArg = block.insertArgument(argNum, newTy, loc);
4848
Value convertedValue = converter.materializeSourceConversion(
49-
builder, op->getLoc(), ty, newArg);
49+
builder, op->getLoc(), ty, newArg, ty);
5050
if (!convertedValue) {
5151
return rewriter.notifyMatchFailure(
5252
op, llvm::formatv("failed to cast new argument {0} to type {1})",

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
159159
// insert a target materialization from the original block argument type to
160160
// a legal type.
161161
addArgumentMaterialization(
162-
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
163-
Location loc) -> std::optional<Value> {
162+
[&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
163+
ValueRange inputs, Type originalType) -> std::optional<Value> {
164164
if (inputs.size() == 1) {
165165
// Bare pointers are not supported for unranked memrefs because a
166166
// memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
174174
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175175
.getResult(0);
176176
});
177-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178-
ValueRange inputs,
179-
Location loc) -> std::optional<Value> {
177+
addArgumentMaterialization([&](OpBuilder &builder, Location loc,
178+
MemRefType resultType, ValueRange inputs,
179+
Type originalType) -> std::optional<Value> {
180180
Value desc;
181181
if (inputs.size() == 1) {
182182
// This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
201201
});
202202
// Add generic source and target materializations to handle cases where
203203
// non-LLVM types persist after an LLVM conversion.
204-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
205-
ValueRange inputs,
206-
Location loc) -> std::optional<Value> {
204+
addSourceMaterialization([&](OpBuilder &builder, Location loc,
205+
Type resultType, ValueRange inputs,
206+
Type originalType) -> std::optional<Value> {
207207
if (inputs.size() != 1)
208208
return std::nullopt;
209209

210210
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211211
.getResult(0);
212212
});
213-
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs,
215-
Location loc) -> std::optional<Value> {
213+
addTargetMaterialization([&](OpBuilder &builder, Location loc,
214+
Type resultType, ValueRange inputs,
215+
Type originalType) -> std::optional<Value> {
216216
if (inputs.size() != 1)
217217
return std::nullopt;
218218

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
11851185
Type indexType = getIndexType();
11861186
if (dimSize.getType() != indexType)
11871187
dimSize = typeConverter->materializeTargetConversion(
1188-
rewriter, loc, indexType, dimSize);
1188+
rewriter, loc, indexType, dimSize, dimSize.getType());
11891189
assert(dimSize && "Invalid memref element type");
11901190
}
11911191

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
9797
// All other types legal
9898
return type;
9999
});
100-
converter.addTargetMaterialization(
101-
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
102-
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
103-
extFOp.setFastmath(arith::FastMathFlags::contract);
104-
return extFOp;
105-
});
100+
converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
101+
ValueRange input, Type originalType) {
102+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
103+
extFOp.setFastmath(arith::FastMathFlags::contract);
104+
return extFOp;
105+
});
106106
}
107107

108108
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
4242
// BufferizeTypeConverter
4343
//===----------------------------------------------------------------------===//
4444

45-
static Value materializeToTensor(OpBuilder &builder, TensorType type,
46-
ValueRange inputs, Location loc) {
45+
static Value materializeToTensor(OpBuilder &builder, Location loc,
46+
TensorType type, ValueRange inputs,
47+
Type originalType) {
4748
assert(inputs.size() == 1);
4849
assert(isa<BaseMemRefType>(inputs[0].getType()));
4950
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
6364
});
6465
addArgumentMaterialization(materializeToTensor);
6566
addSourceMaterialization(materializeToTensor);
66-
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
67-
ValueRange inputs, Location loc) -> Value {
67+
addTargetMaterialization([](OpBuilder &builder, Location loc,
68+
BaseMemRefType type, ValueRange inputs,
69+
Type originalType) -> Value {
6870
assert(inputs.size() == 1 && "expected exactly one input");
6971

7072
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {

mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ using namespace mlir;
1717
namespace {
1818

1919
std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
20-
Type resultType,
20+
Location loc, Type resultType,
2121
ValueRange inputs,
22-
Location loc) {
22+
Type originalType) {
2323
if (inputs.size() != 1)
2424
return std::nullopt;
2525

0 commit comments

Comments
 (0)