Skip to content

Commit acd28ee

Browse files
authored
Avoid public type collisions in GlobalTypeRewriter (#8139)
Move BrandTypeIterator from MinimizeRecGroups to wasm-type-shape.h and use it in a new UniqueRecGroups utility that can rebuild types to be distinct from previously seen rec groups. Use UniqueRecGroups in GlobalTypeRewriter to ensure the newly built private types do not conflict with public types. Split off from #8119 because this can land sooner.
1 parent 98b720d commit acd28ee

File tree

6 files changed

+227
-93
lines changed

6 files changed

+227
-93
lines changed

src/ir/type-updating.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,37 @@
2626

2727
namespace wasm {
2828

29-
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
30-
31-
void GlobalTypeRewriter::update() {
32-
mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
33-
}
34-
35-
GlobalTypeRewriter::PredecessorGraph
36-
GlobalTypeRewriter::getPrivatePredecessors() {
29+
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm)
30+
: wasm(wasm), publicGroups(wasm.features) {
3731
// Find the heap types that are not publicly observable. Even in a closed
3832
// world scenario, don't modify public types because we assume that they may
3933
// be reflected on or used for linking. Figure out where each private type
4034
// will be located in the builder.
41-
auto typeInfo = ModuleUtils::collectHeapTypeInfo(
35+
typeInfo = ModuleUtils::collectHeapTypeInfo(
4236
wasm,
4337
ModuleUtils::TypeInclusion::UsedIRTypes,
4438
ModuleUtils::VisibilityHandling::FindVisibility);
4539

46-
// Check if a type is private, by looking up its info.
40+
std::unordered_set<RecGroup> seenGroups;
41+
for (auto& [type, info] : typeInfo) {
42+
if (info.visibility == ModuleUtils::Visibility::Public) {
43+
auto group = type.getRecGroup();
44+
if (seenGroups.insert(type.getRecGroup()).second) {
45+
std::vector<HeapType> groupTypes(group.begin(), group.end());
46+
publicGroups.insert(std::move(groupTypes));
47+
}
48+
}
49+
}
50+
}
51+
52+
void GlobalTypeRewriter::update() {
53+
mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
54+
}
55+
56+
GlobalTypeRewriter::PredecessorGraph
57+
GlobalTypeRewriter::getPrivatePredecessors() {
58+
// Check if a type is public by looking up its info (every type queried here
59+
// is asserted to have an entry).
4760
auto isPublic = [&](HeapType type) {
4861
auto it = typeInfo.find(type);
4962
assert(it != typeInfo.end());
@@ -185,11 +198,8 @@ GlobalTypeRewriter::rebuildTypes(std::vector<HeapType> types) {
185198
<< " at index " << err->index;
186199
}
187200
#endif
188-
auto& newTypes = *buildResults;
189-
190-
// TODO: It is possible that the newly built rec group matches some public rec
191-
// group. If that is the case, we need to try a different permutation of the
192-
// types or add a brand type to distinguish the private types.
201+
// Ensure the new types are different from any public rec group.
202+
const auto& newTypes = publicGroups.insert(*buildResults);
193203

194204
// Map the old types to the new ones.
195205
TypeMap oldToNewTypes;

src/ir/type-updating.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
#define wasm_ir_type_updating_h
1919

2020
#include "ir/branch-utils.h"
21+
#include "ir/module-utils.h"
2122
#include "support/insert_ordered.h"
2223
#include "wasm-traversal.h"
24+
#include "wasm-type-shape.h"
25+
#include "wasm-type.h"
2326

2427
namespace wasm {
2528

@@ -348,6 +351,13 @@ class GlobalTypeRewriter {
348351

349352
Module& wasm;
350353

354+
// The module's types and their visibilities.
355+
InsertOrderedMap<HeapType, ModuleUtils::HeapTypeInfo> typeInfo;
356+
357+
// The shapes of public rec groups, so we can be sure that the rewritten
358+
// private types do not conflict with public types.
359+
UniqueRecGroups publicGroups;
360+
351361
GlobalTypeRewriter(Module& wasm);
352362
virtual ~GlobalTypeRewriter() {}
353363

src/passes/MinimizeRecGroups.cpp

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -100,84 +100,6 @@ struct TypeSCCs
100100
}
101101
};
102102

103-
// After all their permutations with distinct shapes have been used, different
104-
// groups with the same shapes must be differentiated by adding in a "brand"
105-
// type. Even with a brand mixed in, we might run out of permutations with
106-
// distinct shapes, in which case we need a new brand type. This iterator
107-
// provides an infinite sequence of possible brand types, prioritizing those
108-
// with the most compact encoding.
109-
struct BrandTypeIterator {
110-
static constexpr Index optionCount = 18;
111-
static constexpr std::array<Field, optionCount> fieldOptions = {{
112-
Field(Field::i8, Mutable),
113-
Field(Field::i16, Mutable),
114-
Field(Type::i32, Mutable),
115-
Field(Type::i64, Mutable),
116-
Field(Type::f32, Mutable),
117-
Field(Type::f64, Mutable),
118-
Field(Type(HeapType::any, Nullable), Mutable),
119-
Field(Type(HeapType::func, Nullable), Mutable),
120-
Field(Type(HeapType::ext, Nullable), Mutable),
121-
Field(Type(HeapType::none, Nullable), Mutable),
122-
Field(Type(HeapType::nofunc, Nullable), Mutable),
123-
Field(Type(HeapType::noext, Nullable), Mutable),
124-
Field(Type(HeapType::any, NonNullable), Mutable),
125-
Field(Type(HeapType::func, NonNullable), Mutable),
126-
Field(Type(HeapType::ext, NonNullable), Mutable),
127-
Field(Type(HeapType::none, NonNullable), Mutable),
128-
Field(Type(HeapType::nofunc, NonNullable), Mutable),
129-
Field(Type(HeapType::noext, NonNullable), Mutable),
130-
}};
131-
132-
struct FieldInfo {
133-
uint8_t index = 0;
134-
bool immutable = false;
135-
136-
operator Field() const {
137-
auto field = fieldOptions[index];
138-
if (immutable) {
139-
field.mutable_ = Immutable;
140-
}
141-
return field;
142-
}
143-
144-
bool advance() {
145-
if (!immutable) {
146-
immutable = true;
147-
return true;
148-
}
149-
immutable = false;
150-
index = (index + 1) % optionCount;
151-
return index != 0;
152-
}
153-
};
154-
155-
bool useArray = false;
156-
std::vector<FieldInfo> fields;
157-
158-
HeapType operator*() const {
159-
if (useArray) {
160-
return Array(fields[0]);
161-
}
162-
return Struct(std::vector<Field>(fields.begin(), fields.end()));
163-
}
164-
165-
BrandTypeIterator& operator++() {
166-
for (Index i = fields.size(); i > 0; --i) {
167-
if (fields[i - 1].advance()) {
168-
return *this;
169-
}
170-
}
171-
if (useArray) {
172-
useArray = false;
173-
return *this;
174-
}
175-
fields.emplace_back();
176-
useArray = fields.size() == 1;
177-
return *this;
178-
}
179-
};
180-
181103
// Create an adjacency list with edges from supertype to subtype and from
182104
// described type to descriptor.
183105
std::vector<std::vector<Index>>

src/wasm-type-shape.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#define wasm_wasm_type_shape_h
1919

2020
#include <functional>
21+
#include <list>
22+
#include <unordered_set>
2123
#include <vector>
2224

2325
#include "wasm-features.h"
@@ -79,4 +81,100 @@ template<> class hash<wasm::RecGroupShape> {
7981

8082
} // namespace std
8183

84+
namespace wasm {
85+
86+
// Provides an infinite sequence of possible brand types, prioritizing those
87+
// with the most compact encoding.
88+
struct BrandTypeIterator {
89+
static constexpr Index optionCount = 18;
90+
static constexpr std::array<Field, optionCount> fieldOptions = {{
91+
Field(Field::i8, Mutable),
92+
Field(Field::i16, Mutable),
93+
Field(Type::i32, Mutable),
94+
Field(Type::i64, Mutable),
95+
Field(Type::f32, Mutable),
96+
Field(Type::f64, Mutable),
97+
Field(Type(HeapType::any, Nullable), Mutable),
98+
Field(Type(HeapType::func, Nullable), Mutable),
99+
Field(Type(HeapType::ext, Nullable), Mutable),
100+
Field(Type(HeapType::none, Nullable), Mutable),
101+
Field(Type(HeapType::nofunc, Nullable), Mutable),
102+
Field(Type(HeapType::noext, Nullable), Mutable),
103+
Field(Type(HeapType::any, NonNullable), Mutable),
104+
Field(Type(HeapType::func, NonNullable), Mutable),
105+
Field(Type(HeapType::ext, NonNullable), Mutable),
106+
Field(Type(HeapType::none, NonNullable), Mutable),
107+
Field(Type(HeapType::nofunc, NonNullable), Mutable),
108+
Field(Type(HeapType::noext, NonNullable), Mutable),
109+
}};
110+
111+
struct FieldInfo {
112+
uint8_t index = 0;
113+
bool immutable = false;
114+
115+
operator Field() const {
116+
auto field = fieldOptions[index];
117+
if (immutable) {
118+
field.mutable_ = Immutable;
119+
}
120+
return field;
121+
}
122+
123+
bool advance() {
124+
if (!immutable) {
125+
immutable = true;
126+
return true;
127+
}
128+
immutable = false;
129+
index = (index + 1) % optionCount;
130+
return index != 0;
131+
}
132+
};
133+
134+
bool useArray = false;
135+
std::vector<FieldInfo> fields;
136+
137+
HeapType operator*() const {
138+
if (useArray) {
139+
return Array(fields[0]);
140+
}
141+
return Struct(std::vector<Field>(fields.begin(), fields.end()));
142+
}
143+
144+
BrandTypeIterator& operator++() {
145+
for (Index i = fields.size(); i > 0; --i) {
146+
if (fields[i - 1].advance()) {
147+
return *this;
148+
}
149+
}
150+
if (useArray) {
151+
useArray = false;
152+
return *this;
153+
}
154+
fields.emplace_back();
155+
useArray = fields.size() == 1;
156+
return *this;
157+
}
158+
};
159+
160+
// A set of unique rec group shapes. Upon inserting a new group of types, if it
161+
// has the same shape as a previously inserted group, the types will be rebuilt
162+
// with an extra brand type at the end of the group that differentiates it from
163+
// previous groups.
164+
struct UniqueRecGroups {
165+
std::list<std::vector<HeapType>> groups;
166+
std::unordered_set<RecGroupShape> shapes;
167+
168+
FeatureSet features;
169+
170+
UniqueRecGroups(FeatureSet features) : features(features) {}
171+
172+
// Insert a rec group. If it is already unique, return the original types.
173+
// Otherwise rebuild the group to make it unique and return the rebuilt types,
174+
// including the brand.
175+
const std::vector<HeapType>& insert(std::vector<HeapType> group);
176+
};
177+
178+
} // namespace wasm
179+
82180
#endif // wasm_wasm_type_shape_h

src/wasm/wasm-type-shape.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,41 @@ bool ComparableRecGroupShape::operator>(const RecGroupShape& other) const {
370370
return GT == compareComparable(*this, other);
371371
}
372372

373+
const std::vector<HeapType>&
374+
UniqueRecGroups::insert(std::vector<HeapType> types) {
375+
auto& group = *groups.emplace(groups.end(), std::move(types));
376+
if (shapes.emplace(RecGroupShape(group, features)).second) {
377+
// The types are already unique.
378+
return group;
379+
}
380+
// There is a conflict. Find a brand that makes the group unique.
381+
BrandTypeIterator brand;
382+
group.push_back(*brand);
383+
while (!shapes.emplace(RecGroupShape(group, features)).second) {
384+
group.back() = *++brand;
385+
}
386+
// Rebuild the rec group to include the brand. Map the old types (excluding
387+
// the brand) to their corresponding new types to preserve recursions within
388+
// the group.
389+
Index size = group.size();
390+
TypeBuilder builder(size);
391+
std::unordered_map<HeapType, HeapType> newTypes;
392+
for (Index i = 0; i < size - 1; ++i) {
393+
newTypes[group[i]] = builder[i];
394+
}
395+
for (Index i = 0; i < size; ++i) {
396+
builder[i].copy(group[i], [&](HeapType type) {
397+
if (auto newType = newTypes.find(type); newType != newTypes.end()) {
398+
return newType->second;
399+
}
400+
return type;
401+
});
402+
}
403+
builder.createRecGroup(0, size);
404+
group = *builder.build();
405+
return group;
406+
}
407+
373408
} // namespace wasm
374409

375410
namespace std {
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
2+
;; RUN: wasm-opt %s -all --closed-world --signature-pruning --fuzz-exec -S -o - | filecheck %s
3+
4+
(module
5+
;; CHECK: (type $public (func))
6+
7+
;; CHECK: (rec
8+
;; CHECK-NEXT: (type $private (func))
9+
10+
;; CHECK: (type $2 (struct))
11+
12+
;; CHECK: (type $test (func (result i32)))
13+
(type $test (func (result i32)))
14+
15+
(type $public (func))
16+
17+
;; After signature pruning this will be (func), which is the same as $public.
18+
;; We must make sure we keep $private a distinct type.
19+
(type $private (func (param i32)))
20+
21+
;; CHECK: (import "" "" (func $public (type $public)))
22+
(import "" "" (func $public (type $public)))
23+
24+
;; CHECK: (elem declare func $public)
25+
26+
;; CHECK: (export "test" (func $test))
27+
28+
;; CHECK: (func $private (type $private)
29+
;; CHECK-NEXT: (local $0 i32)
30+
;; CHECK-NEXT: (nop)
31+
;; CHECK-NEXT: )
32+
(func $private (type $private) (param $unused i32)
33+
(nop)
34+
)
35+
36+
;; CHECK: (func $test (type $test) (result i32)
37+
;; CHECK-NEXT: (local $0 funcref)
38+
;; CHECK-NEXT: (ref.test (ref $private)
39+
;; CHECK-NEXT: (select (result funcref)
40+
;; CHECK-NEXT: (ref.func $public)
41+
;; CHECK-NEXT: (local.get $0)
42+
;; CHECK-NEXT: (i32.const 1)
43+
;; CHECK-NEXT: )
44+
;; CHECK-NEXT: )
45+
;; CHECK-NEXT: )
46+
(func $test (export "test") (type $test) (result i32)
47+
(local funcref)
48+
;; Test that $private and $public are separate types. This should return 0.
49+
(ref.test (ref $private)
50+
;; Use select to prevent the ref.test from being optimized in
51+
;; finalization.
52+
(select (result funcref)
53+
(ref.func $public)
54+
(local.get 0)
55+
(i32.const 1)
56+
)
57+
)
58+
)
59+
)

0 commit comments

Comments
 (0)