Skip to content

Commit 1d2e23d

Browse files
authored
[Custom Descriptors] Merge descriptors correctly (#7709)
TypeMerging cannot merge one type in a descriptor chain into another type without merging the type's full descriptor chain into the other type's full descriptor chain. Because of this, each descriptor chain acts as a single unit in the DFA minimization algorithm. Update TypeMerging so that descriptor types do not get their own shapes, but rather are included in the shapes of their base described types. To make this simpler, add a new utility for easily iterating over a type's descriptor chain.
1 parent ad67cf4 commit 1d2e23d

File tree

3 files changed

+503
-87
lines changed

3 files changed

+503
-87
lines changed

src/passes/TypeMerging.cpp

Lines changed: 157 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
// passes in between.
3737
//
3838

39+
#include <algorithm>
40+
#include <unordered_map>
41+
3942
#include "ir/module-utils.h"
4043
#include "ir/type-updating.h"
4144
#include "ir/utils.h"
@@ -44,6 +47,7 @@
4447
#include "support/small_set.h"
4548
#include "wasm-builder.h"
4649
#include "wasm-type-ordering.h"
50+
#include "wasm-type.h"
4751
#include "wasm.h"
4852

4953
#define TYPE_MERGING_DEBUG 0
@@ -115,6 +119,17 @@ struct CastFinder : public PostWalker<CastFinder> {
115119
}
116120
};
117121

122+
HeapType getBaseDescribedType(HeapType type) {
123+
while (true) {
124+
if (auto next = type.getDescribedType()) {
125+
type = *next;
126+
continue;
127+
}
128+
break;
129+
}
130+
return type;
131+
}
132+
118133
// We are going to treat the type graph as a partitioned DFA where each type is
119134
// a state with transitions to its children. We will partition the DFA states so
120135
// that types that may be mergeable will be in the same partition and types that
@@ -339,22 +354,43 @@ bool TypeMerging::merge(MergeKind kind) {
339354
// For each type, either create a new partition or add to its supertype's
340355
// partition.
341356
for (auto type : mergeableSupertypesFirst(mergeable)) {
357+
// Skip descriptor types. Since types in descriptor chains all have to be
358+
// merged into matching descriptor chains together, only the base described
359+
// type in each chain is considered, and its DFA state will include the
360+
// shape of its entire descriptor chain.
361+
if (type.getDescribedType()) {
362+
continue;
363+
}
342364
// We need partitions for any public children of this type since those
343-
// children will participate in the DFA we're creating.
344-
for (auto child : getPublicChildren(type)) {
345-
ensurePartition(child);
365+
// children will participate in the DFA we're creating. We use the base
366+
// described type of the child because that's the type that the DFA state
367+
// for the current type will point to.
368+
for (auto t : type.getDescriptorChain()) {
369+
for (auto child : getPublicChildren(t)) {
370+
ensurePartition(getBaseDescribedType(child));
371+
}
346372
}
347373
// If the type is distinguished by the module or public, we cannot merge it,
348374
// so create a new partition for it.
349-
if (castTypes.count(type) || !privateTypes.count(type)) {
375+
auto chain = type.getDescriptorChain();
376+
bool hasCast =
377+
std::any_of(chain.begin(), chain.end(), [&](HeapType t) -> bool {
378+
return castTypes.count(type);
379+
});
380+
if (hasCast || !privateTypes.count(type)) {
350381
ensurePartition(type);
351382
continue;
352383
}
353384

354385
switch (kind) {
355386
case Supertypes: {
356387
auto super = type.getDeclaredSuperType();
357-
bool superHasExactCast = super && exactCastTypes.count(*super);
388+
bool superHasExactCast =
389+
super &&
390+
std::any_of(chain.begin(), chain.end(), [&](HeapType t) -> bool {
391+
auto super = t.getDeclaredSuperType();
392+
return super && exactCastTypes.count(*super);
393+
});
358394
if (!super || !shapeEq(type, *super) || superHasExactCast) {
359395
// Create a new partition for this type and bail.
360396
ensurePartition(type);
@@ -556,10 +592,20 @@ DFA::State<HeapType> TypeMerging::makeDFAState(HeapType type) {
556592
// other direction, including the children is not necessary to differentiate
557593
// types reached by the public types because all such reachable types are also
558594
// public and not eligible to be merged.
595+
//
596+
// For private types, full descriptor chains are included in a single DFA
597+
// represented by their base described type.
559598
if (privateTypes.count(type)) {
560-
for (auto child : type.getHeapTypeChildren()) {
561-
if (!child.isBasic()) {
562-
succs.push_back(getMerged(child));
599+
assert(!type.getDescribedType());
600+
for (auto t : type.getDescriptorChain()) {
601+
for (auto child : t.getHeapTypeChildren()) {
602+
if (!child.isBasic()) {
603+
// The child's partition is represented by the base of its descriptor
604+
// chain. Different child types in the same descriptor chain are
605+
// differentiated by including their chain index in the hashed
606+
// top-level shape of the parent.
607+
succs.push_back(getMerged(getBaseDescribedType(child)));
608+
}
563609
}
564610
}
565611
}
@@ -571,77 +617,104 @@ void TypeMerging::applyMerges() {
571617
return;
572618
}
573619

574-
// Flatten merges, which might be an arbitrary tree at this point.
620+
// Flatten merges, which might be an arbitrary tree at this point. Also expand
621+
// the mapping to cover every type in each descriptor chain.
622+
std::unordered_map<HeapType, HeapType> replacements;
575623
for (auto [type, _] : merges) {
576-
merges[type] = getMerged(type);
624+
auto target = getMerged(type);
625+
auto chain = type.getDescriptorChain();
626+
auto targetChain = target.getDescriptorChain();
627+
auto targetIt = targetChain.begin();
628+
for (auto it = chain.begin(); it != chain.end(); ++it) {
629+
assert(targetIt != targetChain.end());
630+
replacements[*it] = *targetIt++;
631+
}
577632
}
578633

579634
// We found things to optimize! Rewrite types in the module to apply those
580635
// changes.
581-
TypeMapper(*module, merges).map();
636+
TypeMapper(*module, replacements).map();
582637
}
583638

584639
bool shapeEq(HeapType a, HeapType b) {
585640
// Check whether `a` and `b` have the same top-level structure, including the
586641
// position and identity of any children that are not included as transitions
587-
// in the DFA, i.e. any children that are not nontrivial references.
588-
if (a.isOpen() != b.isOpen()) {
589-
return false;
590-
}
591-
if (a.isShared() != b.isShared()) {
592-
return false;
593-
}
594-
// Ignore supertype because we want to be able to merge into parents.
595-
if (!!a.getDescriptorType() != !!b.getDescriptorType()) {
596-
return false;
597-
}
598-
if (!!a.getDescribedType() != !!b.getDescribedType()) {
599-
return false;
600-
}
601-
auto aKind = a.getKind();
602-
auto bKind = b.getKind();
603-
if (aKind != bKind) {
604-
return false;
642+
// in the DFA, i.e. any children that are not nontrivial references. We treat
643+
// full descriptor chains as single units, so compare the shape of every type
644+
// in the chains rooted at `a` and `b`.
645+
assert(!a.getDescribedType() && !b.getDescribedType());
646+
auto chainA = a.getDescriptorChain();
647+
auto chainB = b.getDescriptorChain();
648+
auto itA = chainA.begin();
649+
auto itB = chainB.begin();
650+
while (itA != chainA.end() && itB != chainB.end()) {
651+
a = *itA++;
652+
b = *itB++;
653+
if (a.isOpen() != b.isOpen()) {
654+
return false;
655+
}
656+
if (a.isShared() != b.isShared()) {
657+
return false;
658+
}
659+
// Ignore supertype because we want to be able to merge into parents.
660+
auto aKind = a.getKind();
661+
auto bKind = b.getKind();
662+
if (aKind != bKind) {
663+
return false;
664+
}
665+
switch (aKind) {
666+
case HeapTypeKind::Func:
667+
if (!shapeEq(a.getSignature(), b.getSignature())) {
668+
return false;
669+
}
670+
break;
671+
case HeapTypeKind::Struct:
672+
if (!shapeEq(a.getStruct(), b.getStruct())) {
673+
return false;
674+
}
675+
break;
676+
case HeapTypeKind::Array:
677+
if (!shapeEq(a.getArray(), b.getArray())) {
678+
return false;
679+
}
680+
break;
681+
case HeapTypeKind::Cont:
682+
WASM_UNREACHABLE("TODO: cont");
683+
case HeapTypeKind::Basic:
684+
WASM_UNREACHABLE("unexpected kind");
685+
}
605686
}
606-
switch (aKind) {
607-
case HeapTypeKind::Func:
608-
return shapeEq(a.getSignature(), b.getSignature());
609-
case HeapTypeKind::Struct:
610-
return shapeEq(a.getStruct(), b.getStruct());
611-
case HeapTypeKind::Array:
612-
return shapeEq(a.getArray(), b.getArray());
613-
case HeapTypeKind::Cont:
614-
WASM_UNREACHABLE("TODO: cont");
615-
case HeapTypeKind::Basic:
616-
WASM_UNREACHABLE("unexpected kind");
617-
}
618-
return false;
687+
return itA == chainA.end() && itB == chainB.end();
619688
}
620689

621690
size_t shapeHash(HeapType a) {
622-
size_t digest = hash(a.isOpen());
623-
rehash(digest, a.isShared());
624-
// Ignore supertype because we want to be able to merge into parents.
625-
rehash(digest, !!a.getDescriptorType());
626-
rehash(digest, !!a.getDescribedType());
627-
auto kind = a.getKind();
628-
rehash(digest, kind);
629-
switch (kind) {
630-
case HeapTypeKind::Func:
631-
hash_combine(digest, shapeHash(a.getSignature()));
632-
return digest;
633-
case HeapTypeKind::Struct:
634-
hash_combine(digest, shapeHash(a.getStruct()));
635-
return digest;
636-
case HeapTypeKind::Array:
637-
hash_combine(digest, shapeHash(a.getArray()));
638-
return digest;
639-
case HeapTypeKind::Cont:
640-
WASM_UNREACHABLE("TODO: cont");
641-
case HeapTypeKind::Basic:
642-
break;
691+
assert(!a.getDescribedType());
692+
size_t digest = 0xA76F35EC;
693+
for (auto type : a.getDescriptorChain()) {
694+
rehash(digest, 0xCC6B0DD9);
695+
rehash(digest, type.isOpen());
696+
rehash(digest, type.isShared());
697+
// Ignore supertype because we want to be able to merge into parents.
698+
auto kind = type.getKind();
699+
rehash(digest, kind);
700+
switch (kind) {
701+
case HeapTypeKind::Func:
702+
hash_combine(digest, shapeHash(type.getSignature()));
703+
continue;
704+
case HeapTypeKind::Struct:
705+
hash_combine(digest, shapeHash(type.getStruct()));
706+
continue;
707+
case HeapTypeKind::Array:
708+
hash_combine(digest, shapeHash(type.getArray()));
709+
continue;
710+
case HeapTypeKind::Cont:
711+
WASM_UNREACHABLE("TODO: cont");
712+
case HeapTypeKind::Basic:
713+
continue;
714+
}
715+
WASM_UNREACHABLE("unexpected kind");
643716
}
644-
WASM_UNREACHABLE("unexpected kind");
717+
return digest;
645718
}
646719

647720
bool shapeEq(const Struct& a, const Struct& b) {
@@ -690,6 +763,18 @@ size_t shapeHash(Field a) {
690763
return digest;
691764
}
692765

766+
Index chainIndex(HeapType type) {
767+
Index i = 0;
768+
while (true) {
769+
if (auto next = type.getDescribedType()) {
770+
type = *next;
771+
++i;
772+
continue;
773+
}
774+
return i;
775+
}
776+
}
777+
693778
bool shapeEq(Type a, Type b) {
694779
if (a == b) {
695780
return true;
@@ -713,6 +798,13 @@ bool shapeEq(Type a, Type b) {
713798
if (a.getExactness() != b.getExactness()) {
714799
return false;
715800
}
801+
// Since partition refinement treats descriptor chains as units, it cannot
802+
// differentiate between different types in the same chain. Two types in the
803+
// same chain will never be merged, so we can differentiate them here by index
804+
// in their chain instead.
805+
if (chainIndex(a.getHeapType()) != chainIndex(b.getHeapType())) {
806+
return false;
807+
}
716808
return true;
717809
}
718810

@@ -735,6 +827,7 @@ size_t shapeHash(Type a) {
735827
rehash(digest, 4);
736828
rehash(digest, (int)a.getNullability());
737829
rehash(digest, (int)a.getExactness());
830+
rehash(digest, chainIndex(a.getHeapType()));
738831
return digest;
739832
}
740833

src/wasm-type.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct Continuation;
5757
struct Field;
5858
struct Struct;
5959
struct Array;
60+
struct DescriptorChain;
6061

6162
using TypeList = std::vector<Type>;
6263
using Tuple = TypeList;
@@ -201,6 +202,7 @@ class HeapType {
201202
// Get this type's descriptor or described types if they exist.
202203
std::optional<HeapType> getDescriptorType() const;
203204
std::optional<HeapType> getDescribedType() const;
205+
DescriptorChain getDescriptorChain() const;
204206

205207
// Return the depth of this heap type in the nominal type hierarchy, i.e. the
206208
// number of supertypes in its supertype chain.
@@ -946,6 +948,49 @@ struct TypeBuilder {
946948
void dump();
947949
};
948950

951+
// An iterable providing access to a heap type's descriptor chain, starting from
952+
// itself and iterating through each successive descriptor type.
953+
struct DescriptorChain {
954+
HeapType base;
955+
struct Iterator {
956+
using iterator_category = std::forward_iterator_tag;
957+
using value_type = HeapType;
958+
using difference_type = std::ptrdiff_t;
959+
using pointer = const HeapType*;
960+
using reference = const HeapType&;
961+
962+
// The current type. An end iterator contains no type.
963+
std::optional<HeapType> type;
964+
965+
reference operator*() const { return *type; }
966+
967+
pointer operator->() const { return &*type; }
968+
969+
Iterator& operator++() {
970+
type = type->getDescriptorType();
971+
return *this;
972+
}
973+
974+
Iterator operator++(int) {
975+
Iterator it = *this;
976+
++(*this);
977+
return it;
978+
}
979+
980+
bool operator==(const Iterator& other) const { return type == other.type; }
981+
982+
bool operator!=(const Iterator& other) const { return !(*this == other); }
983+
};
984+
985+
Iterator begin() const { return Iterator{base}; }
986+
987+
Iterator end() const { return Iterator{std::nullopt}; }
988+
};
989+
990+
inline DescriptorChain HeapType::getDescriptorChain() const {
991+
return DescriptorChain{*this};
992+
}
993+
949994
// We consider certain specific types to always be public, to allow closed-
950995
// world to operate even if they escape. Specifically, "plain old data" types
951996
// like array of i8 and i16, which are used to represent strings, may cross

0 commit comments

Comments
 (0)