Skip to content

Commit 89196f2

Browse files
authored
[Custom Descriptors] Optionally optimize RefCast to RefCastDesc in GlobalStructInference (#7906)
1 parent a8ccf99 commit 89196f2

File tree

2 files changed

+451
-0
lines changed

2 files changed

+451
-0
lines changed

src/passes/GlobalStructInference.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@
5454
//
5555
// TODO: Only do the case with a select when shrinkLevel == 0?
5656
//
57+
// --pass-arg=gsi-desc-casts
58+
//
59+
// Optimize casts to descriptor casts when possible. If a cast has no
60+
// relevant subtypes, and it has a known descriptor, then we can do a
61+
// ref.cast_desc instead, which can be faster (but is larger, so this is
62+
// not on by default yet).
63+
//
5764

5865
#include <variant>
5966

@@ -89,11 +96,21 @@ struct GlobalStructInference : public Pass {
8996
// type-based inference, and this remains empty.
9097
std::unordered_map<HeapType, std::vector<Name>> typeGlobals;
9198

99+
bool optimizeToDescCasts;
100+
101+
std::unique_ptr<SubTypes> subTypes;
102+
92103
void run(Module* module) override {
93104
if (!module->features.hasGC()) {
94105
return;
95106
}
96107

108+
optimizeToDescCasts = hasArgument("gsi-desc-casts");
109+
if (optimizeToDescCasts) {
110+
// We need subtypes to know when to optimize to a desc cast.
111+
subTypes = std::make_unique<SubTypes>(*module);
112+
}
113+
97114
if (getPassOptions().closedWorld) {
98115
analyzeClosedWorld(module);
99116
}
@@ -498,6 +515,55 @@ struct GlobalStructInference : public Pass {
498515
right));
499516
}
500517

518+
void visitRefCast(RefCast* curr) {
519+
// When we see (ref.cast $T), and the type has a descriptor, and that
520+
// descriptor only has a single global, then we can do (ref.cast_desc)
521+
// using the descriptor. Descriptor casts are usually more efficient
522+
// than normal ones (and even more so if we get lucky and are in a loop,
523+
// where the global.get of the descriptor can be hoisted).
524+
// TODO: only do this when shrinkLevel == 0?
525+
if (!parent.optimizeToDescCasts) {
526+
return;
527+
}
528+
529+
// Check if we have a descriptor.
530+
auto type = curr->type;
531+
if (type == Type::unreachable) {
532+
return;
533+
}
534+
auto heapType = type.getHeapType();
535+
auto desc = heapType.getDescriptorType();
536+
if (!desc) {
537+
return;
538+
}
539+
540+
// Check if the type has no (relevant) subtypes, as a ref.cast_desc will
541+
// find precisely that type and nothing else.
542+
if (!type.isExact() &&
543+
!parent.subTypes->getStrictSubTypes(heapType).empty()) {
544+
return;
545+
}
546+
547+
// Check if we have a single global for the descriptor.
548+
auto iter = parent.typeGlobals.find(*desc);
549+
if (iter == parent.typeGlobals.end()) {
550+
return;
551+
}
552+
const auto& globals = iter->second;
553+
if (globals.size() != 1) {
554+
return;
555+
}
556+
557+
// We can optimize!
558+
auto global = globals[0];
559+
auto& wasm = *getModule();
560+
Builder builder(wasm);
561+
auto* getGlobal =
562+
builder.makeGlobalGet(global, wasm.getGlobal(global)->type);
563+
auto* castDesc = builder.makeRefCast(curr->ref, getGlobal, curr->type);
564+
replaceCurrent(castDesc);
565+
}
566+
501567
void visitFunction(Function* func) {
502568
if (refinalize) {
503569
ReFinalize().walkFunctionInModule(func, getModule());

0 commit comments

Comments
 (0)