36
36
// passes in between.
37
37
//
38
38
39
+ #include < algorithm>
40
+ #include < unordered_map>
41
+
39
42
#include " ir/module-utils.h"
40
43
#include " ir/type-updating.h"
41
44
#include " ir/utils.h"
44
47
#include " support/small_set.h"
45
48
#include " wasm-builder.h"
46
49
#include " wasm-type-ordering.h"
50
+ #include " wasm-type.h"
47
51
#include " wasm.h"
48
52
49
53
#define TYPE_MERGING_DEBUG 0
@@ -115,6 +119,17 @@ struct CastFinder : public PostWalker<CastFinder> {
115
119
}
116
120
};
117
121
122
+ HeapType getBaseDescribedType (HeapType type) {
123
+ while (true ) {
124
+ if (auto next = type.getDescribedType ()) {
125
+ type = *next;
126
+ continue ;
127
+ }
128
+ break ;
129
+ }
130
+ return type;
131
+ }
132
+
118
133
// We are going to treat the type graph as a partitioned DFA where each type is
119
134
// a state with transitions to its children. We will partition the DFA states so
120
135
// that types that may be mergeable will be in the same partition and types that
@@ -339,22 +354,43 @@ bool TypeMerging::merge(MergeKind kind) {
339
354
// For each type, either create a new partition or add to its supertype's
340
355
// partition.
341
356
for (auto type : mergeableSupertypesFirst (mergeable)) {
357
+ // Skip descriptor types. Since types in descriptor chains all have to be
358
+ // merged into matching descriptor chains together, only the base described
359
+ // type in each chain is considered, and its DFA state will include the
360
+ // shape of its entire descriptor chain.
361
+ if (type.getDescribedType ()) {
362
+ continue ;
363
+ }
342
364
// We need partitions for any public children of this type since those
343
- // children will participate in the DFA we're creating.
344
- for (auto child : getPublicChildren (type)) {
345
- ensurePartition (child);
365
+ // children will participate in the DFA we're creating. We use the base
366
+ // described type of the child because that's the type that the DFA state
367
+ // for the current type will point to.
368
+ for (auto t : type.getDescriptorChain ()) {
369
+ for (auto child : getPublicChildren (t)) {
370
+ ensurePartition (getBaseDescribedType (child));
371
+ }
346
372
}
347
373
// If the type is distinguished by the module or public, we cannot merge it,
348
374
// so create a new partition for it.
349
- if (castTypes.count (type) || !privateTypes.count (type)) {
375
+ auto chain = type.getDescriptorChain ();
376
+ bool hasCast =
377
+ std::any_of (chain.begin (), chain.end (), [&](HeapType t) -> bool {
378
+ return castTypes.count (type);
379
+ });
380
+ if (hasCast || !privateTypes.count (type)) {
350
381
ensurePartition (type);
351
382
continue ;
352
383
}
353
384
354
385
switch (kind) {
355
386
case Supertypes: {
356
387
auto super = type.getDeclaredSuperType ();
357
- bool superHasExactCast = super && exactCastTypes.count (*super);
388
+ bool superHasExactCast =
389
+ super &&
390
+ std::any_of (chain.begin (), chain.end (), [&](HeapType t) -> bool {
391
+ auto super = t.getDeclaredSuperType ();
392
+ return super && exactCastTypes.count (*super);
393
+ });
358
394
if (!super || !shapeEq (type, *super) || superHasExactCast) {
359
395
// Create a new partition for this type and bail.
360
396
ensurePartition (type);
@@ -556,10 +592,20 @@ DFA::State<HeapType> TypeMerging::makeDFAState(HeapType type) {
556
592
// other direction, including the children is not necessary to differentiate
557
593
// types reached by the public types because all such reachable types are also
558
594
// public and not eligible to be merged.
595
+ //
596
+ // For private types, full descriptor chains are included in a single DFA
597
+ // represented by their base described type.
559
598
if (privateTypes.count (type)) {
560
- for (auto child : type.getHeapTypeChildren ()) {
561
- if (!child.isBasic ()) {
562
- succs.push_back (getMerged (child));
599
+ assert (!type.getDescribedType ());
600
+ for (auto t : type.getDescriptorChain ()) {
601
+ for (auto child : t.getHeapTypeChildren ()) {
602
+ if (!child.isBasic ()) {
603
+ // The child's partition is represented by the base of its descriptor
604
+ // chain. Different child types in the same descriptor chain are
605
+ // differentiated by including their chain index in the hashed
606
+ // top-level shape of the parent.
607
+ succs.push_back (getMerged (getBaseDescribedType (child)));
608
+ }
563
609
}
564
610
}
565
611
}
@@ -571,77 +617,104 @@ void TypeMerging::applyMerges() {
571
617
return ;
572
618
}
573
619
574
- // Flatten merges, which might be an arbitrary tree at this point.
620
+ // Flatten merges, which might be an arbitrary tree at this point. Also expand
621
+ // the mapping to cover every type in each descriptor chain.
622
+ std::unordered_map<HeapType, HeapType> replacements;
575
623
for (auto [type, _] : merges) {
576
- merges[type] = getMerged (type);
624
+ auto target = getMerged (type);
625
+ auto chain = type.getDescriptorChain ();
626
+ auto targetChain = target.getDescriptorChain ();
627
+ auto targetIt = targetChain.begin ();
628
+ for (auto it = chain.begin (); it != chain.end (); ++it) {
629
+ assert (targetIt != targetChain.end ());
630
+ replacements[*it] = *targetIt++;
631
+ }
577
632
}
578
633
579
634
// We found things to optimize! Rewrite types in the module to apply those
580
635
// changes.
581
- TypeMapper (*module , merges ).map ();
636
+ TypeMapper (*module , replacements ).map ();
582
637
}
583
638
584
639
bool shapeEq (HeapType a, HeapType b) {
585
640
// Check whether `a` and `b` have the same top-level structure, including the
586
641
// position and identity of any children that are not included as transitions
587
- // in the DFA, i.e. any children that are not nontrivial references.
588
- if (a.isOpen () != b.isOpen ()) {
589
- return false ;
590
- }
591
- if (a.isShared () != b.isShared ()) {
592
- return false ;
593
- }
594
- // Ignore supertype because we want to be able to merge into parents.
595
- if (!!a.getDescriptorType () != !!b.getDescriptorType ()) {
596
- return false ;
597
- }
598
- if (!!a.getDescribedType () != !!b.getDescribedType ()) {
599
- return false ;
600
- }
601
- auto aKind = a.getKind ();
602
- auto bKind = b.getKind ();
603
- if (aKind != bKind) {
604
- return false ;
642
+ // in the DFA, i.e. any children that are not nontrivial references. We treat
643
+ // full descriptor chains as single units, so compare the shape of every type
644
+ // in the chains rooted at `a` and `b`.
645
+ assert (!a.getDescribedType () && !b.getDescribedType ());
646
+ auto chainA = a.getDescriptorChain ();
647
+ auto chainB = b.getDescriptorChain ();
648
+ auto itA = chainA.begin ();
649
+ auto itB = chainB.begin ();
650
+ while (itA != chainA.end () && itB != chainB.end ()) {
651
+ a = *itA++;
652
+ b = *itB++;
653
+ if (a.isOpen () != b.isOpen ()) {
654
+ return false ;
655
+ }
656
+ if (a.isShared () != b.isShared ()) {
657
+ return false ;
658
+ }
659
+ // Ignore supertype because we want to be able to merge into parents.
660
+ auto aKind = a.getKind ();
661
+ auto bKind = b.getKind ();
662
+ if (aKind != bKind) {
663
+ return false ;
664
+ }
665
+ switch (aKind) {
666
+ case HeapTypeKind::Func:
667
+ if (!shapeEq (a.getSignature (), b.getSignature ())) {
668
+ return false ;
669
+ }
670
+ break ;
671
+ case HeapTypeKind::Struct:
672
+ if (!shapeEq (a.getStruct (), b.getStruct ())) {
673
+ return false ;
674
+ }
675
+ break ;
676
+ case HeapTypeKind::Array:
677
+ if (!shapeEq (a.getArray (), b.getArray ())) {
678
+ return false ;
679
+ }
680
+ break ;
681
+ case HeapTypeKind::Cont:
682
+ WASM_UNREACHABLE (" TODO: cont" );
683
+ case HeapTypeKind::Basic:
684
+ WASM_UNREACHABLE (" unexpected kind" );
685
+ }
605
686
}
606
- switch (aKind) {
607
- case HeapTypeKind::Func:
608
- return shapeEq (a.getSignature (), b.getSignature ());
609
- case HeapTypeKind::Struct:
610
- return shapeEq (a.getStruct (), b.getStruct ());
611
- case HeapTypeKind::Array:
612
- return shapeEq (a.getArray (), b.getArray ());
613
- case HeapTypeKind::Cont:
614
- WASM_UNREACHABLE (" TODO: cont" );
615
- case HeapTypeKind::Basic:
616
- WASM_UNREACHABLE (" unexpected kind" );
617
- }
618
- return false ;
687
+ return itA == chainA.end () && itB == chainB.end ();
619
688
}
620
689
621
690
size_t shapeHash (HeapType a) {
622
- size_t digest = hash (a.isOpen ());
623
- rehash (digest, a.isShared ());
624
- // Ignore supertype because we want to be able to merge into parents.
625
- rehash (digest, !!a.getDescriptorType ());
626
- rehash (digest, !!a.getDescribedType ());
627
- auto kind = a.getKind ();
628
- rehash (digest, kind);
629
- switch (kind) {
630
- case HeapTypeKind::Func:
631
- hash_combine (digest, shapeHash (a.getSignature ()));
632
- return digest;
633
- case HeapTypeKind::Struct:
634
- hash_combine (digest, shapeHash (a.getStruct ()));
635
- return digest;
636
- case HeapTypeKind::Array:
637
- hash_combine (digest, shapeHash (a.getArray ()));
638
- return digest;
639
- case HeapTypeKind::Cont:
640
- WASM_UNREACHABLE (" TODO: cont" );
641
- case HeapTypeKind::Basic:
642
- break ;
691
+ assert (!a.getDescribedType ());
692
+ size_t digest = 0xA76F35EC ;
693
+ for (auto type : a.getDescriptorChain ()) {
694
+ rehash (digest, 0xCC6B0DD9 );
695
+ rehash (digest, type.isOpen ());
696
+ rehash (digest, type.isShared ());
697
+ // Ignore supertype because we want to be able to merge into parents.
698
+ auto kind = type.getKind ();
699
+ rehash (digest, kind);
700
+ switch (kind) {
701
+ case HeapTypeKind::Func:
702
+ hash_combine (digest, shapeHash (type.getSignature ()));
703
+ continue ;
704
+ case HeapTypeKind::Struct:
705
+ hash_combine (digest, shapeHash (type.getStruct ()));
706
+ continue ;
707
+ case HeapTypeKind::Array:
708
+ hash_combine (digest, shapeHash (type.getArray ()));
709
+ continue ;
710
+ case HeapTypeKind::Cont:
711
+ WASM_UNREACHABLE (" TODO: cont" );
712
+ case HeapTypeKind::Basic:
713
+ continue ;
714
+ }
715
+ WASM_UNREACHABLE (" unexpected kind" );
643
716
}
644
- WASM_UNREACHABLE ( " unexpected kind " ) ;
717
+ return digest ;
645
718
}
646
719
647
720
bool shapeEq (const Struct& a, const Struct& b) {
@@ -690,6 +763,18 @@ size_t shapeHash(Field a) {
690
763
return digest;
691
764
}
692
765
766
+ Index chainIndex (HeapType type) {
767
+ Index i = 0 ;
768
+ while (true ) {
769
+ if (auto next = type.getDescribedType ()) {
770
+ type = *next;
771
+ ++i;
772
+ continue ;
773
+ }
774
+ return i;
775
+ }
776
+ }
777
+
693
778
bool shapeEq (Type a, Type b) {
694
779
if (a == b) {
695
780
return true ;
@@ -713,6 +798,13 @@ bool shapeEq(Type a, Type b) {
713
798
if (a.getExactness () != b.getExactness ()) {
714
799
return false ;
715
800
}
801
+ // Since partition refinement treats descriptor chains as units, it cannot
802
+ // differentiate between different types in the same chain. Two types in the
803
+ // same chain will never be merged, so we can differentiate them here by index
804
+ // in their chain instead.
805
+ if (chainIndex (a.getHeapType ()) != chainIndex (b.getHeapType ())) {
806
+ return false ;
807
+ }
716
808
return true ;
717
809
}
718
810
@@ -735,6 +827,7 @@ size_t shapeHash(Type a) {
735
827
rehash (digest, 4 );
736
828
rehash (digest, (int )a.getNullability ());
737
829
rehash (digest, (int )a.getExactness ());
830
+ rehash (digest, chainIndex (a.getHeapType ()));
738
831
return digest;
739
832
}
740
833
0 commit comments