Skip to content

Commit b22f94d

Browse files
authored
[MLIR] Enable caching of type conversion in the presence of context-aware conversion (#158072)
The current implementation is overly conservative and disable all possible caching as soon as a context-aware conversion is present. However the context-aware conversion only affects subsequent converters, we can cache the previous ones. This isn't NFC because if fixed a bug where we use to unconditionally cache when using the `convertType(Type t, ...` API, while now all APIs are aware of context-aware conversions.
1 parent 4ce74bf commit b22f94d

File tree

3 files changed

+66
-45
lines changed

3 files changed

+66
-45
lines changed

mlir/docs/DialectConversion.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a
285285
`Type`. A context-aware conversion function converts a `Value` into a type. The
286286
latter allows users to customize type conversion rules based on the IR.
287287

288-
Note: When there is at least one context-aware type conversion function, the
289-
result of type conversions can no longer be cached, which can increase
290-
compilation time. Use this feature with caution!
288+
Note: context-aware type conversion functions impact the ability of the
289+
framework to cache the conversion result. In the absence of a context-aware
290+
conversion, all context-free type conversions can be cached. Otherwise only the
291+
context-free conversions added after a context-aware type conversion can be
292+
cached (conversions are applied in reverse order).
293+
As such it is advised to add context-aware conversions as early as possible in
294+
the sequence of `addConversion` calls (so that they apply last).
291295

292296
A `materialization` describes how a list of values should be converted to a
293297
list of values with specific types. An important distinction from a

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ class TypeConverter {
433433
std::is_same_v<T, Value>,
434434
ConversionCallbackFn>
435435
wrapCallback(FnT &&callback) {
436-
hasContextAwareTypeConversions = true;
436+
contextAwareTypeConversionsIndex = conversions.size();
437437
return [callback = std::forward<FnT>(callback)](
438438
PointerUnion<Type, Value> typeOrValue,
439439
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -555,6 +555,10 @@ class TypeConverter {
555555
cachedMultiConversions.clear();
556556
}
557557

558+
/// Internal implementation of the type conversion.
559+
LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
560+
SmallVectorImpl<Type> &results) const;
561+
558562
/// The set of registered conversion functions.
559563
SmallVector<ConversionCallbackFn, 4> conversions;
560564

@@ -575,10 +579,13 @@ class TypeConverter {
575579
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
576580
/// Whether the type converter has context-aware type conversions. I.e.,
577581
/// conversion rules that depend on the SSA value instead of just the type.
578-
/// Type conversion caching is deactivated when there are context-aware
579-
/// conversions because the type converter may return different results for
580-
/// the same input type.
581-
bool hasContextAwareTypeConversions = false;
582+
/// We store here the index in the `conversions` vector of the last added
583+
/// context-aware conversion, if any. This is useful because we can't cache
584+
/// the result of type conversion happening after context-aware conversions,
585+
/// because the type converter may return different results for the same input
586+
/// type. This is why it is recommened to add context-aware conversions first,
587+
/// any context-free conversions after will benefit from caching.
588+
int contextAwareTypeConversionsIndex = -1;
582589
};
583590

584591
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput(
34063406
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
34073407
}
34083408

3409-
LogicalResult TypeConverter::convertType(Type t,
3410-
SmallVectorImpl<Type> &results) const {
3411-
assert(t && "expected non-null type");
3412-
3409+
/// Internal implementation of the type conversion.
3410+
/// This is used with either a Type or a Value as the first argument.
3411+
/// - we can cache the context-free conversions until the last registered
3412+
/// context-aware conversion.
3413+
/// - we can't cache the result of type conversion happening after context-aware
3414+
/// conversions, because the type converter may return different results for the
3415+
/// same input type.
3416+
LogicalResult
3417+
TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3418+
SmallVectorImpl<Type> &results) const {
3419+
assert(typeOrValue && "expected non-null type");
3420+
Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3421+
: cast<Type>(typeOrValue);
34133422
{
34143423
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
34153424
std::defer_lock);
@@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t,
34313440
// registered first.
34323441
size_t currentCount = results.size();
34333442

3443+
// We can cache the context-free conversions until the last registered
3444+
// context-aware conversion. But only if we're processing a Value right now.
3445+
auto isCacheable = [&](int index) {
3446+
int numberOfConversionsUntilContextAware =
3447+
conversions.size() - 1 - contextAwareTypeConversionsIndex;
3448+
return index < numberOfConversionsUntilContextAware;
3449+
};
3450+
34343451
std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
34353452
std::defer_lock);
34363453

3437-
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3438-
if (std::optional<LogicalResult> result = converter(t, results)) {
3439-
if (t.getContext()->isMultithreadingEnabled())
3440-
cacheWriteLock.lock();
3441-
if (!succeeded(*result)) {
3442-
assert(results.size() == currentCount &&
3443-
"failed type conversion should not change results");
3444-
cachedDirectConversions.try_emplace(t, nullptr);
3445-
return failure();
3446-
}
3447-
auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3448-
if (newTypes.size() == 1)
3449-
cachedDirectConversions.try_emplace(t, newTypes.front());
3450-
else
3451-
cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3454+
for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3455+
const ConversionCallbackFn &converter = indexedConverter.value();
3456+
std::optional<LogicalResult> result = converter(typeOrValue, results);
3457+
if (!result) {
3458+
assert(results.size() == currentCount &&
3459+
"failed type conversion should not change results");
3460+
continue;
3461+
}
3462+
if (!isCacheable(indexedConverter.index()))
34523463
return success();
3453-
} else {
3464+
if (t.getContext()->isMultithreadingEnabled())
3465+
cacheWriteLock.lock();
3466+
if (!succeeded(*result)) {
34543467
assert(results.size() == currentCount &&
34553468
"failed type conversion should not change results");
3469+
cachedDirectConversions.try_emplace(t, nullptr);
3470+
return failure();
34563471
}
3472+
auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3473+
if (newTypes.size() == 1)
3474+
cachedDirectConversions.try_emplace(t, newTypes.front());
3475+
else
3476+
cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3477+
return success();
34573478
}
34583479
return failure();
34593480
}
34603481

3461-
LogicalResult TypeConverter::convertType(Value v,
3482+
LogicalResult TypeConverter::convertType(Type t,
34623483
SmallVectorImpl<Type> &results) const {
3463-
assert(v && "expected non-null value");
3464-
3465-
// If this type converter does not have context-aware type conversions, call
3466-
// the type-based overload, which has caching.
3467-
if (!hasContextAwareTypeConversions)
3468-
return convertType(v.getType(), results);
3484+
return convertTypeImpl(t, results);
3485+
}
34693486

3470-
// Walk the added converters in reverse order to apply the most recently
3471-
// registered first.
3472-
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3473-
if (std::optional<LogicalResult> result = converter(v, results)) {
3474-
if (!succeeded(*result))
3475-
return failure();
3476-
return success();
3477-
}
3478-
}
3479-
return failure();
3487+
LogicalResult TypeConverter::convertType(Value v,
3488+
SmallVectorImpl<Type> &results) const {
3489+
return convertTypeImpl(v, results);
34803490
}
34813491

34823492
Type TypeConverter::convertType(Type t) const {

0 commit comments

Comments
 (0)