Skip to content

Commit 7830e4d

Browse files
committed
[flang] improve DITypeAttr caching with recursive derived types
1 parent 7ec494e commit 7830e4d

File tree

4 files changed

+354
-111
lines changed

4 files changed

+354
-111
lines changed

flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp

Lines changed: 127 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ DebugTypeGenerator::DebugTypeGenerator(mlir::ModuleOp m,
4848
mlir::SymbolTable *symbolTable_,
4949
const mlir::DataLayout &dl)
5050
: module(m), symbolTable(symbolTable_), dataLayout{&dl},
51-
kindMapping(getKindMapping(m)), llvmTypeConverter(m, false, false, dl),
52-
derivedTypeDepth(0) {
51+
kindMapping(getKindMapping(m)), llvmTypeConverter(m, false, false, dl) {
5352
LLVM_DEBUG(llvm::dbgs() << "DITypeAttr generator\n");
5453

5554
mlir::MLIRContext *context = module.getContext();
@@ -272,31 +271,127 @@ DebugTypeGenerator::getFieldSizeAndAlign(mlir::Type fieldTy) {
272271
return std::pair{byteSize, byteAlign};
273272
}
274273

274+
mlir::LLVM::DITypeAttr DerivedTypeCache::lookup(mlir::Type type) {
275+
auto iter = typeCache.find(type);
276+
if (iter != typeCache.end()) {
277+
if (iter->second.first) {
278+
componentActiveRecursionLevels = iter->second.second;
279+
}
280+
return iter->second.first;
281+
}
282+
return nullptr;
283+
}
284+
285+
DerivedTypeCache::ActiveLevels
286+
DerivedTypeCache::startTranslating(mlir::Type type,
287+
mlir::LLVM::DITypeAttr placeHolder) {
288+
derivedTypeDepth++;
289+
if (!placeHolder)
290+
return {};
291+
typeCache[type] = std::pair<mlir::LLVM::DITypeAttr, ActiveLevels>(
292+
placeHolder, {derivedTypeDepth});
293+
return {};
294+
}
295+
296+
void DerivedTypeCache::preComponentVisitUpdate() {
297+
componentActiveRecursionLevels.clear();
298+
}
299+
300+
void DerivedTypeCache::postComponentVisitUpdate(
301+
ActiveLevels &activeRecursionLevels) {
302+
if (componentActiveRecursionLevels.empty())
303+
return;
304+
ActiveLevels oldLevels;
305+
oldLevels.swap(activeRecursionLevels);
306+
std::merge(componentActiveRecursionLevels.begin(),
307+
componentActiveRecursionLevels.end(), oldLevels.begin(),
308+
oldLevels.end(), std::back_inserter(activeRecursionLevels));
309+
}
310+
311+
void DerivedTypeCache::finalize(mlir::Type ty, mlir::LLVM::DITypeAttr attr,
312+
ActiveLevels &&activeRecursionLevels) {
313+
// If there is no nested recursion or if this type does not point to any type
314+
// nodes above it, it is safe to cache it indefinitely (it can be used in any
315+
// contexts).
316+
if (activeRecursionLevels.empty() ||
317+
(activeRecursionLevels[0] == derivedTypeDepth)) {
318+
typeCache[ty] = std::pair<mlir::LLVM::DITypeAttr, ActiveLevels>(attr, {});
319+
componentActiveRecursionLevels.clear();
320+
cleanUpCache(derivedTypeDepth);
321+
--derivedTypeDepth;
322+
return;
323+
}
324+
// Trim any recursion below the current type.
325+
if (activeRecursionLevels.back() >= derivedTypeDepth) {
326+
auto last = llvm::find_if(activeRecursionLevels, [&](std::int32_t depth) {
327+
return depth >= derivedTypeDepth;
328+
});
329+
if (last != activeRecursionLevels.end()) {
330+
activeRecursionLevels.erase(last, activeRecursionLevels.end());
331+
}
332+
}
333+
componentActiveRecursionLevels = std::move(activeRecursionLevels);
334+
typeCache[ty] = std::pair<mlir::LLVM::DITypeAttr, ActiveLevels>(
335+
attr, componentActiveRecursionLevels);
336+
cleanUpCache(derivedTypeDepth);
337+
if (!componentActiveRecursionLevels.empty())
338+
insertCacheCleanUp(ty, componentActiveRecursionLevels.back());
339+
--derivedTypeDepth;
340+
}
341+
342+
void DerivedTypeCache::insertCacheCleanUp(mlir::Type type, int32_t depth) {
343+
auto iter = llvm::find_if(cacheCleanupList,
344+
[&](const auto &x) { return x.second >= depth; });
345+
if (iter == cacheCleanupList.end()) {
346+
cacheCleanupList.emplace_back(
347+
std::pair<llvm::SmallVector<mlir::Type>, int32_t>({type}, depth));
348+
return;
349+
}
350+
if (iter->second == depth) {
351+
iter->first.push_back(type);
352+
return;
353+
}
354+
cacheCleanupList.insert(
355+
iter, std::pair<llvm::SmallVector<mlir::Type>, int32_t>({type}, depth));
356+
}
357+
358+
void DerivedTypeCache::cleanUpCache(int32_t depth) {
359+
if (cacheCleanupList.empty())
360+
return;
361+
// cleanups are done in the post actions when visiting a derived type
362+
// tree. So if there is a clean-up for the current depth, it has to be
363+
// the last one (deeper ones must have been done already).
364+
if (cacheCleanupList.back().second == depth) {
365+
for (mlir::Type type : cacheCleanupList.back().first)
366+
typeCache[type].first = nullptr;
367+
cacheCleanupList.pop_back_n(1);
368+
}
369+
}
370+
275371
mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
276372
fir::RecordType Ty, mlir::LLVM::DIFileAttr fileAttr,
277373
mlir::LLVM::DIScopeAttr scope, fir::cg::XDeclareOp declOp) {
278-
// Check if this type has already been converted.
279-
auto iter = typeCache.find(Ty);
280-
if (iter != typeCache.end())
281-
return iter->second;
282374

283-
bool canCacheThisType = true;
284-
llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
375+
if (mlir::LLVM::DITypeAttr attr = derivedTypeCache.lookup(Ty))
376+
return attr;
377+
285378
mlir::MLIRContext *context = module.getContext();
286-
auto recId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
379+
auto [nameKind, sourceName] = fir::NameUniquer::deconstruct(Ty.getName());
380+
if (nameKind != fir::NameUniquer::NameKind::DERIVED_TYPE)
381+
return genPlaceholderType(context);
382+
383+
llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
287384
// Generate a place holder TypeAttr which will be used if a member
288385
// references the parent type.
289-
auto comAttr = mlir::LLVM::DICompositeTypeAttr::get(
386+
auto recId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
387+
auto placeHolder = mlir::LLVM::DICompositeTypeAttr::get(
290388
context, recId, /*isRecSelf=*/true, llvm::dwarf::DW_TAG_structure_type,
291389
mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope,
292390
/*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0,
293391
/*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
294392
/*allocated=*/nullptr, /*associated=*/nullptr);
295-
typeCache[Ty] = comAttr;
296-
297-
auto result = fir::NameUniquer::deconstruct(Ty.getName());
298-
if (result.first != fir::NameUniquer::NameKind::DERIVED_TYPE)
299-
return genPlaceholderType(context);
393+
DerivedTypeCache::ActiveLevels nestedRecursions =
394+
derivedTypeCache.startTranslating(Ty, placeHolder);
300395

301396
fir::TypeInfoOp tiOp = symbolTable->lookup<fir::TypeInfoOp>(Ty.getName());
302397
unsigned line = (tiOp) ? getLineFromLoc(tiOp.getLoc()) : 1;
@@ -305,6 +400,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
305400
mlir::IntegerType intTy = mlir::IntegerType::get(context, 64);
306401
std::uint64_t offset = 0;
307402
for (auto [fieldName, fieldTy] : Ty.getTypeList()) {
403+
derivedTypeCache.preComponentVisitUpdate();
308404
auto [byteSize, byteAlign] = getFieldSizeAndAlign(fieldTy);
309405
std::optional<llvm::ArrayRef<int64_t>> lowerBounds =
310406
fir::getComponentLowerBoundsIfNonDefault(Ty, fieldName, module,
@@ -317,22 +413,22 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
317413
mlir::LLVM::DITypeAttr elemTy;
318414
if (lowerBounds && seqTy &&
319415
lowerBounds->size() == seqTy.getShape().size()) {
320-
llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
416+
llvm::SmallVector<mlir::LLVM::DINodeAttr> arrayElements;
321417
for (auto [bound, dim] :
322418
llvm::zip_equal(*lowerBounds, seqTy.getShape())) {
323419
auto countAttr = mlir::IntegerAttr::get(intTy, llvm::APInt(64, dim));
324420
auto lowerAttr = mlir::IntegerAttr::get(intTy, llvm::APInt(64, bound));
325421
auto subrangeTy = mlir::LLVM::DISubrangeAttr::get(
326422
context, countAttr, lowerAttr, /*upperBound=*/nullptr,
327423
/*stride=*/nullptr);
328-
elements.push_back(subrangeTy);
424+
arrayElements.push_back(subrangeTy);
329425
}
330426
elemTy = mlir::LLVM::DICompositeTypeAttr::get(
331427
context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr,
332428
/*file=*/nullptr, /*line=*/0, /*scope=*/nullptr,
333429
convertType(seqTy.getEleTy(), fileAttr, scope, declOp),
334430
mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0,
335-
elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
431+
arrayElements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
336432
/*allocated=*/nullptr, /*associated=*/nullptr);
337433
} else
338434
elemTy = convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr);
@@ -344,96 +440,37 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
344440
/*extra data=*/nullptr);
345441
elements.push_back(tyAttr);
346442
offset += llvm::alignTo(byteSize, byteAlign);
347-
348-
// Currently, the handling of recursive debug type in mlir has some
349-
// limitations that were discussed at the end of the thread for following
350-
// PR.
351-
// https://github.com/llvm/llvm-project/pull/106571
352-
//
353-
// Problem could be explained with the following example code:
354-
// type t2
355-
// type(t1), pointer :: p1
356-
// end type
357-
// type t1
358-
// type(t2), pointer :: p2
359-
// end type
360-
// In the description below, type_self means a temporary type that is
361-
// generated
362-
// as a place holder while the members of that type are being processed.
363-
//
364-
// If we process t1 first then we will have the following structure after
365-
// it has been processed.
366-
// t1 -> t2 -> t1_self
367-
// This is because when we started processing t2, we did not have the
368-
// complete t1 but its place holder t1_self.
369-
// Now if some entity requires t2, we will already have that in cache and
370-
// will return it. But this t2 refers to t1_self and not to t1. In mlir
371-
// handling, only those types are allowed to have _self reference which are
372-
// wrapped by entity whose reference it is. So t1 -> t2 -> t1_self is ok
373-
// because the t1_self reference can be resolved by the outer t1. But
374-
// standalone t2 is not because there will be no way to resolve it. Until
375-
// this is fixed in mlir, we avoid caching such types. Please see
376-
// DebugTranslation::translateRecursive for details on how mlir handles
377-
// recursive types.
378-
// The code below checks for situation where it will be unsafe to cache
379-
// a type to avoid this problem. We do that in 2 situations.
380-
// 1. If a member is record type, then its type would have been processed
381-
// before reaching here. If it is not in the cache, it means that it was
382-
// found to be unsafe to cache. So any type containing it will also not
383-
// be cached
384-
// 2. The type of the member is found in the cache but it is a place holder.
385-
// In this case, its recID should match the recID of the type we are
386-
// processing. This helps us to cache the following type.
387-
// type t
388-
// type(t), allocatable :: p
389-
// end type
390-
mlir::Type baseTy = getDerivedType(fieldTy);
391-
if (auto recTy = mlir::dyn_cast<fir::RecordType>(baseTy)) {
392-
auto iter = typeCache.find(recTy);
393-
if (iter == typeCache.end())
394-
canCacheThisType = false;
395-
else {
396-
if (auto tyAttr =
397-
mlir::dyn_cast<mlir::LLVM::DICompositeTypeAttr>(iter->second)) {
398-
if (tyAttr.getIsRecSelf() && tyAttr.getRecId() != recId)
399-
canCacheThisType = false;
400-
}
401-
}
402-
}
443+
derivedTypeCache.postComponentVisitUpdate(nestedRecursions);
403444
}
404445

405446
auto finalAttr = mlir::LLVM::DICompositeTypeAttr::get(
406447
context, recId, /*isRecSelf=*/false, llvm::dwarf::DW_TAG_structure_type,
407-
mlir::StringAttr::get(context, result.second.name), fileAttr, line, scope,
448+
mlir::StringAttr::get(context, sourceName.name), fileAttr, line, scope,
408449
/*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8,
409450
/*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
410451
/*allocated=*/nullptr, /*associated=*/nullptr);
411452

412-
// derivedTypeDepth == 1 means that it is a top level type which is safe to
413-
// cache.
414-
if (canCacheThisType || derivedTypeDepth == 1) {
415-
typeCache[Ty] = finalAttr;
416-
} else {
417-
auto iter = typeCache.find(Ty);
418-
if (iter != typeCache.end())
419-
typeCache.erase(iter);
420-
}
453+
derivedTypeCache.finalize(Ty, finalAttr, std::move(nestedRecursions));
454+
421455
return finalAttr;
422456
}
423457

424458
mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
425459
mlir::TupleType Ty, mlir::LLVM::DIFileAttr fileAttr,
426460
mlir::LLVM::DIScopeAttr scope, fir::cg::XDeclareOp declOp) {
427461
// Check if this type has already been converted.
428-
auto iter = typeCache.find(Ty);
429-
if (iter != typeCache.end())
430-
return iter->second;
462+
if (mlir::LLVM::DITypeAttr attr = derivedTypeCache.lookup(Ty))
463+
return attr;
464+
465+
DerivedTypeCache::ActiveLevels nestedRecursions =
466+
derivedTypeCache.startTranslating(Ty);
431467

432468
llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
433469
mlir::MLIRContext *context = module.getContext();
434470

435471
std::uint64_t offset = 0;
436472
for (auto fieldTy : Ty.getTypes()) {
473+
derivedTypeCache.preComponentVisitUpdate();
437474
auto [byteSize, byteAlign] = getFieldSizeAndAlign(fieldTy);
438475
mlir::LLVM::DITypeAttr elemTy =
439476
convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr);
@@ -445,6 +482,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
445482
/*extra data=*/nullptr);
446483
elements.push_back(tyAttr);
447484
offset += llvm::alignTo(byteSize, byteAlign);
485+
derivedTypeCache.postComponentVisitUpdate(nestedRecursions);
448486
}
449487

450488
auto typeAttr = mlir::LLVM::DICompositeTypeAttr::get(
@@ -453,7 +491,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
453491
/*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8,
454492
/*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
455493
/*allocated=*/nullptr, /*associated=*/nullptr);
456-
typeCache[Ty] = typeAttr;
494+
derivedTypeCache.finalize(Ty, typeAttr, std::move(nestedRecursions));
457495
return typeAttr;
458496
}
459497

@@ -667,27 +705,7 @@ DebugTypeGenerator::convertType(mlir::Type Ty, mlir::LLVM::DIFileAttr fileAttr,
667705
return convertCharacterType(charTy, fileAttr, scope, declOp,
668706
/*hasDescriptor=*/false);
669707
} else if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(Ty)) {
670-
// For nested derived types like shown below, the call sequence of the
671-
// convertRecordType will look something like as follows:
672-
// convertRecordType (t1)
673-
// convertRecordType (t2)
674-
// convertRecordType (t3)
675-
// We need to recognize when we are processing the top level type like t1
676-
// to make caching decision. The variable `derivedTypeDepth` is used for
677-
// this purpose and maintains the current depth of derived type processing.
678-
// type t1
679-
// type(t2), pointer :: p1
680-
// end type
681-
// type t2
682-
// type(t3), pointer :: p2
683-
// end type
684-
// type t2
685-
// integer a
686-
// end type
687-
derivedTypeDepth++;
688-
auto result = convertRecordType(recTy, fileAttr, scope, declOp);
689-
derivedTypeDepth--;
690-
return result;
708+
return convertRecordType(recTy, fileAttr, scope, declOp);
691709
} else if (auto tupleTy = mlir::dyn_cast_if_present<mlir::TupleType>(Ty)) {
692710
return convertTupleType(tupleTy, fileAttr, scope, declOp);
693711
} else if (auto refTy = mlir::dyn_cast_if_present<fir::ReferenceType>(Ty)) {

0 commit comments

Comments
 (0)