Commit 9570477

[Custom Descriptors] Fix desc casts in AbstractTypeRefining (#7742)
When a described type is never allocated, AbstractTypeRefining will optimize it to the bottom heap type `none`. However, if that type is used as the target in any descriptor casts, subsequent refinalization will restore the original type as derived from the type of the descriptor operand. To avoid reintroducing the optimized-out types, fix up affected descriptor casts to be non-descriptor casts before refinalizing. This is valid because we know that only null values can ever pass these casts, so the descriptors are never used.

As a drive-by, also use a more efficient (and less verbose) method of collecting the public types in AbstractTypeRefining.
1 parent 6c54803 commit 9570477
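For context, this is roughly how the pass can be driven through Binaryen's C API. A minimal sketch, assuming the pass is registered under the name "abstract-type-refining" and that it refuses to run outside closed-world mode; neither detail is stated in this commit:

#include "binaryen-c.h"

int main() {
  BinaryenModuleRef module = BinaryenModuleCreate();
  // ... build or load a module in which some described type is never
  // allocated but is still the target of a ref.cast_desc ...
  BinaryenModuleSetFeatures(module, BinaryenFeatureAll());
  // Assumption: the pass requires closed-world mode.
  BinaryenSetClosedWorld(true);
  const char* passes[] = {"abstract-type-refining"};
  BinaryenModuleRunPasses(module, passes, 1);
  BinaryenModuleDispose(module);
  return 0;
}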

2 files changed: +700 -8 lines

src/passes/AbstractTypeRefining.cpp
72 additions & 8 deletions

@@ -33,6 +33,9 @@
 // must fail unless it allows null.
 //
 
+#include <memory>
+
+#include "ir/localize.h"
 #include "ir/module-utils.h"
 #include "ir/subtypes.h"
 #include "ir/type-updating.h"
@@ -112,14 +115,8 @@ struct AbstractTypeRefining : public Pass {
     // module, given closed world, but we'd also need to make sure that
     // we don't need to make any changes to public types that refer to
     // them.
-    auto heapTypes = ModuleUtils::collectHeapTypeInfo(
-      *module,
-      ModuleUtils::TypeInclusion::AllTypes,
-      ModuleUtils::VisibilityHandling::FindVisibility);
-    for (auto& [type, info] : heapTypes) {
-      if (info.visibility == ModuleUtils::Visibility::Public) {
-        createdTypes.insert(type);
-      }
+    for (auto type : ModuleUtils::getPublicHeapTypes(*module)) {
+      createdTypes.insert(type);
     }
 
     SubTypes subTypes(*module);
@@ -243,6 +240,12 @@
 
     TypeMapper::TypeUpdates mapping;
 
+    // Track whether we optimize any described types to bottom. If we do, then
+    // we could end up with descriptor casts to nullref, which need to be fixed
+    // up before ReFinalize reintroduces the cast result type that was supposed
+    // to be optimized out.
+    bool optimizedDescribedType = false;
+
     for (auto type : subTypes.types) {
       if (!type.isStruct()) {
         // TODO: support arrays and funcs
@@ -259,6 +262,9 @@
       // We check this first as it is the most powerful change.
      if (createdTypesOrSubTypes.count(type) == 0) {
         mapping[type] = type.getBottom();
+        if (type.getDescriptorType()) {
+          optimizedDescribedType = true;
+        }
         continue;
       }
 
@@ -297,11 +303,69 @@
 
     AbstractTypeRefiningTypeMapper(*module, mapping).map();
 
+    if (optimizedDescribedType) {
+      // At this point we may have casts like this:
+      //
+      //   (ref.cast_desc nullref
+      //     (some struct...)
+      //     (some desc...)
+      //   )
+      //
+      // ReFinalize would fix up the cast target to be the struct type described
+      // by the descriptor, but that struct type was supposed to have been
+      // optimized out. Optimize out the cast (which we know must either be a
+      // null check or unconditional trap) before ReFinalize can get to it.
+      fixupDescriptorCasts(*module);
+    }
+
     // Refinalize to propagate the type changes we made. For example, a refined
     // cast may lead to a struct.get reading a more refined type using that
     // type.
     ReFinalize().run(getPassRunner(), module);
   }
+
+  void fixupDescriptorCasts(Module& module) {
+    struct CastFixer : WalkerPass<PostWalker<CastFixer>> {
+      bool isFunctionParallel() override { return true; }
+      std::unique_ptr<Pass> create() override {
+        return std::make_unique<CastFixer>();
+      }
+      void visitRefCast(RefCast* curr) {
+        if (!curr->desc || !curr->type.isNull()) {
+          return;
+        }
+        Block* replacement =
+          ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
+            .getChildrenReplacement();
+        // Reuse `curr` as the cast to nullref. Leave further optimization of
+        // the cast to OptimizeInstructions.
+        curr->desc = nullptr;
+        replacement->list.push_back(curr);
+        replacement->type = curr->type;
+        replaceCurrent(replacement);
+      }
+      void visitBrOn(BrOn* curr) {
+        if (curr->op != BrOnCastDesc && curr->op != BrOnCastDescFail) {
+          return;
+        }
+        if (!curr->castType.isNull()) {
+          return;
+        }
+        bool isFail = curr->op == BrOnCastDescFail;
+        Block* replacement =
+          ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
+            .getChildrenReplacement();
+        // Reuse `curr` as a br_on_cast to nullref. Leave further optimization
+        // of the branch to RemoveUnusedBrs.
+        curr->desc = nullptr;
+        curr->op = isFail ? BrOnCastFail : BrOnCast;
+        replacement->list.push_back(curr);
+        replacement->type = curr->type;
+        replaceCurrent(replacement);
+      }
+    } fixer;
+    fixer.run(getPassRunner(), &module);
+  }
 };
 
 } // anonymous namespace
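To make the fixup concrete, here is the shape of the rewrite CastFixer performs on a ref.cast_desc, sketched in the same comment notation the patch itself uses. The $x and $y scratch locals are hypothetical names: ChildLocalizer picks fresh locals, and spilling the children is what preserves the descriptor operand's side effects even though the rewritten cast no longer uses it.

// Before: after type mapping, the cast target is already nullref.
//
//   (ref.cast_desc nullref
//     (X)   ;; some struct operand
//     (Y))  ;; some descriptor operand
//
// After fixupDescriptorCasts, conceptually:
//
//   (block (result nullref)
//     (local.set $x (X))
//     (local.set $y (Y))  ;; Y is still evaluated; its value is now unused
//     (ref.cast nullref
//       (local.get $x)))
//
// Since no value of the described type can exist, only null can reach the
// cast, so dropping the descriptor check cannot change behavior: the plain
// ref.cast to nullref still passes nulls and traps on anything else.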
