1818#include " llvm/ADT/MapVector.h"
1919#include " llvm/ADT/StringMap.h"
2020#include < type_traits>
21+ #include < variant>
2122
2223namespace mlir {
2324
@@ -139,7 +140,8 @@ class TypeConverter {
139140 };
140141
141142 // / Register a conversion function. A conversion function must be convertible
142- // / to any of the following forms (where `T` is a class derived from `Type`):
143+ // / to any of the following forms (where `T` is `Value` or a class derived
144+ // / from `Type`, including `Type` itself):
143145 // /
144146 // / * std::optional<Type>(T)
145147 // / - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
154156 // / `std::nullopt` is returned, the converter is allowed to try another
155157 // / conversion function to perform the conversion.
156158 // /
159+ // / Conversion functions that accept `Value` as the first argument are
160+ // / context-aware. I.e., they can take into account IR when converting the
161+ // / type of the given value. Context-unaware conversion functions accept
162+ // / `Type` or a derived class as the first argument.
163+ // /
164+ // / Note: Context-unaware conversions are cached, but context-aware
165+ // / conversions are not.
166+ // /
157167 // / Note: When attempting to convert a type, e.g. via 'convertType', the
158168 // / mostly recently added conversions will be invoked first.
159169 template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
241251 wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
242252 }
243253
244- // / Convert the given type. This function should return failure if no valid
254+ // / Convert the given type. This function returns failure if no valid
245255 // / conversion exists, success otherwise. If the new set of types is empty,
246256 // / the type is removed and any usages of the existing value are expected to
247257 // / be removed during conversion.
258+ // /
259+ // / Note: This overload invokes only context-unaware type conversion
260+ // / functions. Users should call the other overload if possible.
248261 LogicalResult convertType (Type t, SmallVectorImpl<Type> &results) const ;
249262
263+ // / Convert the type of the given value. This function returns failure if no
264+ // / valid conversion exists, success otherwise. If the new set of types is
265+ // / empty, the type is removed and any usages of the existing value are
266+ // / expected to be removed during conversion.
267+ // /
268+ // / Note: This overload invokes both context-aware and context-unaware type
269+ // / conversion functions.
270+ LogicalResult convertType (Value v, SmallVectorImpl<Type> &results) const ;
271+
250272 // / This hook simplifies defining 1-1 type conversions. This function returns
251273 // / the type to convert to on success, and a null type on failure.
252274 Type convertType (Type t) const ;
275+ Type convertType (Value v) const ;
253276
254277 // / Attempts a 1-1 type conversion, expecting the result type to be
255278 // / `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
258281 TargetType convertType (Type t) const {
259282 return dyn_cast_or_null<TargetType>(convertType (t));
260283 }
284+ template <typename TargetType>
285+ TargetType convertType (Value v) const {
286+ return dyn_cast_or_null<TargetType>(convertType (v));
287+ }
261288
262- // / Convert the given set of types, filling 'results' as necessary. This
263- // / returns failure if the conversion of any of the types fails, success
289+ // / Convert the given types, filling 'results' as necessary. This returns
290+ // / " failure" if the conversion of any of the types fails, " success"
264291 // / otherwise.
265292 LogicalResult convertTypes (TypeRange types,
266293 SmallVectorImpl<Type> &results) const ;
267294
295+ // / Convert the types of the given values, filling 'results' as necessary.
296+ // / This returns "failure" if the conversion of any of the types fails,
297+ // / "success" otherwise.
298+ LogicalResult convertTypes (ValueRange values,
299+ SmallVectorImpl<Type> &results) const ;
300+
268301 // / Return true if the given type is legal for this type converter, i.e. the
269302 // / type converts to itself.
270303 bool isLegal (Type type) const ;
@@ -328,7 +361,7 @@ class TypeConverter {
328361 // / types is empty, the type is removed and any usages of the existing value
329362 // / are expected to be removed during conversion.
330363 using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
331- Type, SmallVectorImpl<Type> &)>;
364+ std::variant< Type, Value> , SmallVectorImpl<Type> &)>;
332365
333366 // / The signature of the callback used to materialize a source conversion.
334367 // /
@@ -348,13 +381,14 @@ class TypeConverter {
348381
349382 // / Generate a wrapper for the given callback. This allows for accepting
350383 // / different callback forms, that all compose into a single version.
351- // / With callback of form: `std::optional<Type>(T)`
384+ // / With callback of form: `std::optional<Type>(T)`, where `T` can be a
385+ // / `Value` or a `Type` (or a class derived from `Type`).
352386 template <typename T, typename FnT>
353387 std::enable_if_t <std::is_invocable_v<FnT, T>, ConversionCallbackFn>
354- wrapCallback (FnT &&callback) const {
388+ wrapCallback (FnT &&callback) {
355389 return wrapCallback<T>([callback = std::forward<FnT>(callback)](
356- T type , SmallVectorImpl<Type> &results) {
357- if (std::optional<Type> resultOpt = callback (type )) {
390+ T typeOrValue , SmallVectorImpl<Type> &results) {
391+ if (std::optional<Type> resultOpt = callback (typeOrValue )) {
358392 bool wasSuccess = static_cast <bool >(*resultOpt);
359393 if (wasSuccess)
360394 results.push_back (*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
364398 });
365399 }
366400 // / With callback of form: `std::optional<LogicalResult>(
367- // / T, SmallVectorImpl<Type> &, ArrayRef<Type>)` .
401+ // / T, SmallVectorImpl<Type> &)`, where `T` is a type .
368402 template <typename T, typename FnT>
369- std::enable_if_t <std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
403+ std::enable_if_t <std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
404+ std::is_base_of_v<Type, T>,
370405 ConversionCallbackFn>
371406 wrapCallback (FnT &&callback) const {
372407 return [callback = std::forward<FnT>(callback)](
373- Type type,
408+ std::variant< Type, Value> type,
374409 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
375- T derivedType = dyn_cast<T>(type);
410+ T derivedType;
411+ if (Type *t = std::get_if<Type>(&type)) {
412+ derivedType = dyn_cast<T>(*t);
413+ } else if (Value *v = std::get_if<Value>(&type)) {
414+ derivedType = dyn_cast<T>(v->getType ());
415+ } else {
416+ llvm_unreachable (" unexpected variant" );
417+ }
376418 if (!derivedType)
377419 return std::nullopt ;
378420 return callback (derivedType, results);
379421 };
380422 }
423+ // / With callback of form: `std::optional<LogicalResult>(
424+ // / T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
425+ template <typename T, typename FnT>
426+ std::enable_if_t <std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
427+ std::is_same_v<T, Value>,
428+ ConversionCallbackFn>
429+ wrapCallback (FnT &&callback) {
430+ hasContextAwareTypeConversions = true ;
431+ return [callback = std::forward<FnT>(callback)](
432+ std::variant<Type, Value> type,
433+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
434+ if (Type *t = std::get_if<Type>(&type)) {
435+ // Context-aware type conversion was called with a type.
436+ return std::nullopt ;
437+ } else if (Value *v = std::get_if<Value>(&type)) {
438+ return callback (*v, results);
439+ }
440+ llvm_unreachable (" unexpected variant" );
441+ return std::nullopt ;
442+ };
443+ }
381444
382445 // / Register a type conversion.
383446 void registerConversion (ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
504567 mutable DenseMap<Type, SmallVector<Type, 2 >> cachedMultiConversions;
505568 // / A mutex used for cache access
506569 mutable llvm::sys::SmartRWMutex<true > cacheMutex;
570+ // / Whether the type converter has context-aware type conversions. I.e.,
571+ // / conversion rules that depend on the SSA value instead of just the type.
572+ // / Type conversion caching is deactivated when there are context-aware
573+ // / conversions because the type converter may return different results for
574+ // / the same input type.
575+ bool hasContextAwareTypeConversions = false ;
507576};
508577
509578// ===----------------------------------------------------------------------===//
0 commit comments