Skip to content

Commit cdf8139

Browse files
authored
ConstantFieldPropagation: Add a variation that picks between 2 values using RefTest (#6692)
CFP focuses on finding when a field always contains a constant, and then replaces a struct.get with that constant. If we find there are two constant values, then in some cases we can still optimize, if we have a way to pick between them. All we have is the struct.get and its reference, so we must use a ref.test: (struct.get $T x (..ref..)) => (select (..constant1..) (..constant2..) (ref.test $U (..ref..)) ) This is valid if, of all the subtypes of $T, those that pass the test have constant1 in that field, and those that fail the test have constant2. For example, a simple case is where $T has two subtypes, $T is never created itself, and each of the two subtypes has a different constant value. This is a somewhat risky operation, as ref.test is not necessarily cheap. To mitigate that, this is a new pass, --cfp-reftest that is not run by default, and also we only optimize when we can use a ref.test on what we think will be a final type (because ref.test on a final type can be faster in VMs).
1 parent 53712b6 commit cdf8139

File tree

7 files changed

+1723
-16
lines changed

7 files changed

+1723
-16
lines changed

src/passes/ConstantFieldPropagation.cpp

Lines changed: 251 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,30 @@
2323
// write to that field of a different value (even using a subtype of T), then
2424
// anywhere we see a get of that field we can place a ref.func of F.
2525
//
26+
// A variation of this pass also uses ref.test to optimize. This is riskier, as
27+
// adding a ref.test means we are adding a non-trivial amount of work, and
28+
// whether it helps overall depends on subsequent optimizations, so we do not do
29+
// it by default. In this variation, if we inferred a field has exactly two
30+
// possible values, and we can differentiate between them using a ref.test, then
31+
// we do
32+
//
33+
// (struct.get $T x (..ref..))
34+
// =>
35+
// (select
36+
// (..constant1..)
37+
// (..constant2..)
38+
// (ref.test $U (..ref..))
39+
// )
40+
//
41+
// This is valid if, of all the subtypes of $T, those that pass the test have
42+
// constant1 in that field, and those that fail the test have constant2. For
43+
// example, a simple case is where $T has two subtypes, $T is never created
44+
// itself, and each of the two subtypes has a different constant value. (Note
45+
// that we do similar things in e.g. GlobalStructInference, where we turn a
46+
// struct.get into a select, but the risk there is much lower since the
47+
// condition for the select is something like a ref.eq - very cheap - while here
48+
// we emit a ref.test which in general is as expensive as a cast.)
49+
//
2650
// FIXME: This pass assumes a closed world. When we start to allow multi-module
2751
// wasm GC programs we need to check for type escaping.
2852
//
@@ -34,6 +58,7 @@
3458
#include "ir/struct-utils.h"
3559
#include "ir/utils.h"
3660
#include "pass.h"
61+
#include "support/small_vector.h"
3762
#include "wasm-builder.h"
3863
#include "wasm-traversal.h"
3964
#include "wasm.h"
@@ -73,17 +98,30 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
7398
// Only modifies struct.get operations.
7499
bool requiresNonNullableLocalFixups() override { return false; }
75100

101+
// We receive the propagated infos, that is, info about field types in a form
102+
// that takes into account subtypes for quick computation, and also the raw
103+
// subtyping and new infos (information about struct.news).
76104
std::unique_ptr<Pass> create() override {
77-
return std::make_unique<FunctionOptimizer>(infos);
105+
return std::make_unique<FunctionOptimizer>(
106+
propagatedInfos, subTypes, rawNewInfos, refTest);
78107
}
79108

80-
FunctionOptimizer(PCVStructValuesMap& infos) : infos(infos) {}
109+
FunctionOptimizer(const PCVStructValuesMap& propagatedInfos,
110+
const SubTypes& subTypes,
111+
const PCVStructValuesMap& rawNewInfos,
112+
bool refTest)
113+
: propagatedInfos(propagatedInfos), subTypes(subTypes),
114+
rawNewInfos(rawNewInfos), refTest(refTest) {}
81115

82116
void visitStructGet(StructGet* curr) {
83117
auto type = curr->ref->type;
84118
if (type == Type::unreachable) {
85119
return;
86120
}
121+
auto heapType = type.getHeapType();
122+
if (!heapType.isStruct()) {
123+
return;
124+
}
87125

88126
Builder builder(*getModule());
89127

@@ -92,8 +130,8 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
92130
// as if nothing was ever noted for that field.
93131
PossibleConstantValues info;
94132
assert(!info.hasNoted());
95-
auto iter = infos.find(type.getHeapType());
96-
if (iter != infos.end()) {
133+
auto iter = propagatedInfos.find(heapType);
134+
if (iter != propagatedInfos.end()) {
97135
// There is information on this type, fetch it.
98136
info = iter->second[curr->index];
99137
}
@@ -113,25 +151,204 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
113151
return;
114152
}
115153

116-
// If the value is not a constant, then it is unknown and we must give up.
154+
// If the value is not a constant, then it is unknown and we must give up
155+
// on simply applying a constant. However, we can try to use a ref.test, if
156+
// that is allowed.
117157
if (!info.isConstant()) {
158+
if (refTest) {
159+
optimizeUsingRefTest(curr);
160+
}
118161
return;
119162
}
120163

121164
// We can do this! Replace the get with a trap on a null reference using a
122165
// ref.as_non_null (we need to trap as the get would have done so), plus the
123166
// constant value. (Leave it to further optimizations to get rid of the
124167
// ref.)
125-
Expression* value = info.makeExpression(*getModule());
126-
auto field = GCTypeUtils::getField(type, curr->index);
127-
assert(field);
128-
value =
129-
Bits::makePackedFieldGet(value, *field, curr->signed_, *getModule());
168+
auto* value = makeExpression(info, heapType, curr);
130169
replaceCurrent(builder.makeSequence(
131170
builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)), value));
132171
changed = true;
133172
}
134173

174+
// Given information about a constant value, and the struct type and StructGet
175+
// that reads it, create an expression for that value.
176+
Expression* makeExpression(const PossibleConstantValues& info,
177+
HeapType type,
178+
StructGet* curr) {
179+
auto* value = info.makeExpression(*getModule());
180+
auto field = GCTypeUtils::getField(type, curr->index);
181+
assert(field);
182+
return Bits::makePackedFieldGet(value, *field, curr->signed_, *getModule());
183+
}
184+
185+
void optimizeUsingRefTest(StructGet* curr) {
186+
auto refType = curr->ref->type;
187+
auto refHeapType = refType.getHeapType();
188+
189+
// We only handle immutable fields in this function, as we will be looking
190+
// at |rawNewInfos|. That is, we are trying to see when a type and its
191+
// subtypes have different values (so that we can differentiate between them
192+
// using a ref.test), and those differences are lost in |propagatedInfos|,
193+
// which has propagated to relevant types so that we can do a single check
194+
// to see what value could be there. So we need to use something more
195+
// precise, |rawNewInfos|, which tracks the values written to struct.news,
196+
// where we know the type exactly (unlike with a struct.set). But for that
197+
// reason the field must be immutable, so that it is valid to only look at
198+
// the struct.news. (A more complex flow analysis could do better here, but
199+
// would be far beyond the scope of this pass.)
200+
if (GCTypeUtils::getField(refType, curr->index)->mutable_ == Mutable) {
201+
return;
202+
}
203+
204+
// We seek two possible constant values. For each we track the constant and
205+
// the types that have that constant. For example, if we have types A, B, C
206+
// and A and B have 42 in their field, and C has 1337, then we'd have this:
207+
//
208+
// values = [ { 42, [A, B] }, { 1337, [C] } ];
209+
struct Value {
210+
PossibleConstantValues constant;
211+
// Use a SmallVector as we'll only have 2 Values, and so the stack usage
212+
// here is fixed.
213+
SmallVector<HeapType, 10> types;
214+
215+
// Whether this slot is used. If so, |constant| has a value, and |types|
216+
// is not empty.
217+
bool used() const {
218+
if (constant.hasNoted()) {
219+
assert(!types.empty());
220+
return true;
221+
}
222+
assert(types.empty());
223+
return false;
224+
}
225+
} values[2];
226+
227+
// Handle one of the subtypes of the relevant type. We check what value it
228+
// has for the field, and update |values|. If we hit a problem, we mark us
229+
// as having failed.
230+
auto fail = false;
231+
auto handleType = [&](HeapType type, Index depth) {
232+
if (fail) {
233+
// TODO: Add a mechanism to halt |iterSubTypes| in the middle, as once
234+
// we fail there is no point to further iterating.
235+
return;
236+
}
237+
238+
auto iter = rawNewInfos.find(type);
239+
if (iter == rawNewInfos.end()) {
240+
// This type has no struct.news, so we can ignore it: it is abstract.
241+
return;
242+
}
243+
244+
auto value = iter->second[curr->index];
245+
if (!value.isConstant()) {
246+
// The value here is not constant, so give up entirely.
247+
fail = true;
248+
return;
249+
}
250+
251+
// Consider the constant value compared to previous ones.
252+
for (Index i = 0; i < 2; i++) {
253+
if (!values[i].used()) {
254+
// There is nothing in this slot: place this value there.
255+
values[i].constant = value;
256+
values[i].types.push_back(type);
257+
break;
258+
}
259+
260+
// There is something in this slot. If we have the same value, append.
261+
if (values[i].constant == value) {
262+
values[i].types.push_back(type);
263+
break;
264+
}
265+
266+
// Otherwise, this value is different than values[i], which is fine:
267+
// we can add it as the second value in the next loop iteration - at
268+
// least, we can do that if there is another iteration: If it's already
269+
// the last, we've failed to find only two values.
270+
if (i == 1) {
271+
fail = true;
272+
return;
273+
}
274+
}
275+
};
276+
subTypes.iterSubTypes(refHeapType, handleType);
277+
278+
if (fail) {
279+
return;
280+
}
281+
282+
// We either filled slot 0, or we did not, and if we did not then cannot
283+
// have filled slot 1 after it.
284+
assert(values[0].used() || !values[1].used());
285+
286+
if (!values[1].used()) {
287+
// We did not see two constant values (we might have seen just one, or
288+
// even no constant values at all).
289+
return;
290+
}
291+
292+
// We have exactly two values to pick between. We can pick between those
293+
// values using a single ref.test if the two sets of types are actually
294+
// disjoint. In general we could compute the LUB of each set and see if it
295+
// overlaps with the other, but for efficiency we only want to do this
296+
// optimization if the type we test on is closed/final, since ref.test on a
297+
// final type can be fairly fast (perhaps constant time). We therefore look
298+
// if one of the sets of types contains a single type and it is final, and
299+
// if so then we'll test on it. (However, see a few lines below on how we
300+
// test for finality.)
301+
// TODO: Consider adding a variation on this pass that uses non-final types.
302+
auto isProperTestType = [&](const Value& value) -> std::optional<HeapType> {
303+
auto& types = value.types;
304+
if (types.size() != 1) {
305+
// Too many types.
306+
return {};
307+
}
308+
309+
auto type = types[0];
310+
// Do not test finality using isOpen(), as that may only be applied late
311+
// in the optimization pipeline. We are in closed-world here, so just
312+
// see if there are subtypes in practice (if not, this can be marked as
313+
// final later, and we assume optimistically that it will).
314+
if (!subTypes.getImmediateSubTypes(type).empty()) {
315+
// There are subtypes.
316+
return {};
317+
}
318+
319+
// Success, we can test on this.
320+
return type;
321+
};
322+
323+
// Look for the index in |values| to test on.
324+
Index testIndex;
325+
if (auto test = isProperTestType(values[0])) {
326+
testIndex = 0;
327+
} else if (auto test = isProperTestType(values[1])) {
328+
testIndex = 1;
329+
} else {
330+
// We failed to find a simple way to separate the types.
331+
return;
332+
}
333+
334+
// Success! We can replace the struct.get with a select over the two values
335+
// (and a trap on null) with the proper ref.test.
336+
Builder builder(*getModule());
337+
338+
auto& testIndexTypes = values[testIndex].types;
339+
assert(testIndexTypes.size() == 1);
340+
auto testType = testIndexTypes[0];
341+
342+
auto* nnRef = builder.makeRefAs(RefAsNonNull, curr->ref);
343+
344+
replaceCurrent(builder.makeSelect(
345+
builder.makeRefTest(nnRef, Type(testType, NonNullable)),
346+
makeExpression(values[testIndex].constant, refHeapType, curr),
347+
makeExpression(values[1 - testIndex].constant, refHeapType, curr)));
348+
349+
changed = true;
350+
}
351+
135352
void doWalkFunction(Function* func) {
136353
WalkerPass<PostWalker<FunctionOptimizer>>::doWalkFunction(func);
137354

@@ -143,7 +360,10 @@ struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
143360
}
144361

145362
private:
146-
PCVStructValuesMap& infos;
363+
const PCVStructValuesMap& propagatedInfos;
364+
const SubTypes& subTypes;
365+
const PCVStructValuesMap& rawNewInfos;
366+
const bool refTest;
147367

148368
bool changed = false;
149369
};
@@ -193,6 +413,11 @@ struct ConstantFieldPropagation : public Pass {
193413
// Only modifies struct.get operations.
194414
bool requiresNonNullableLocalFixups() override { return false; }
195415

416+
// Whether we are optimizing using ref.test, see above.
417+
const bool refTest;
418+
419+
ConstantFieldPropagation(bool refTest) : refTest(refTest) {}
420+
196421
void run(Module* module) override {
197422
if (!module->features.hasGC()) {
198423
return;
@@ -214,8 +439,16 @@ struct ConstantFieldPropagation : public Pass {
214439
BoolStructValuesMap combinedCopyInfos;
215440
functionCopyInfos.combineInto(combinedCopyInfos);
216441

442+
// Prepare data we will need later.
217443
SubTypes subTypes(*module);
218444

445+
PCVStructValuesMap rawNewInfos;
446+
if (refTest) {
447+
// The refTest optimizations require the raw new infos (see above), but we
448+
// can skip copying here if we'll never read this.
449+
rawNewInfos = combinedNewInfos;
450+
}
451+
219452
// Handle subtyping. |combinedInfo| so far contains data that represents
220453
// each struct.new and struct.set's operation on the struct type used in
221454
// that instruction. That is, if we do a struct.set to type T, the value was
@@ -288,17 +521,19 @@ struct ConstantFieldPropagation : public Pass {
288521

289522
// Optimize.
290523
// TODO: Skip this if we cannot optimize anything
291-
FunctionOptimizer(combinedInfos).run(runner, module);
292-
293-
// TODO: Actually remove the field from the type, where possible? That might
294-
// be best in another pass.
524+
FunctionOptimizer(combinedInfos, subTypes, rawNewInfos, refTest)
525+
.run(runner, module);
295526
}
296527
};
297528

298529
} // anonymous namespace
299530

300531
Pass* createConstantFieldPropagationPass() {
301-
return new ConstantFieldPropagation();
532+
return new ConstantFieldPropagation(false);
533+
}
534+
535+
Pass* createConstantFieldPropagationRefTestPass() {
536+
return new ConstantFieldPropagation(true);
302537
}
303538

304539
} // namespace wasm

src/passes/pass.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ void PassRegistry::registerPasses() {
121121
registerPass("cfp",
122122
"propagate constant struct field values",
123123
createConstantFieldPropagationPass);
124+
registerPass("cfp-reftest",
125+
"propagate constant struct field values, using ref.test",
126+
createConstantFieldPropagationRefTestPass);
124127
registerPass(
125128
"dce", "removes unreachable code", createDeadCodeEliminationPass);
126129
registerPass("dealign",

src/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Pass* createCodeFoldingPass();
3232
Pass* createCodePushingPass();
3333
Pass* createConstHoistingPass();
3434
Pass* createConstantFieldPropagationPass();
35+
Pass* createConstantFieldPropagationRefTestPass();
3536
Pass* createDAEPass();
3637
Pass* createDAEOptimizingPass();
3738
Pass* createDataFlowOptsPass();

test/lit/help/wasm-opt.test

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@
103103
;; CHECK-NEXT: --cfp propagate constant struct field
104104
;; CHECK-NEXT: values
105105
;; CHECK-NEXT:
106+
;; CHECK-NEXT: --cfp-reftest propagate constant struct field
107+
;; CHECK-NEXT: values, using ref.test
108+
;; CHECK-NEXT:
106109
;; CHECK-NEXT: --coalesce-locals reduce # of locals by coalescing
107110
;; CHECK-NEXT:
108111
;; CHECK-NEXT: --coalesce-locals-learning reduce # of locals by coalescing

test/lit/help/wasm2js.test

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@
5757
;; CHECK-NEXT: --cfp propagate constant struct field
5858
;; CHECK-NEXT: values
5959
;; CHECK-NEXT:
60+
;; CHECK-NEXT: --cfp-reftest propagate constant struct field
61+
;; CHECK-NEXT: values, using ref.test
62+
;; CHECK-NEXT:
6063
;; CHECK-NEXT: --coalesce-locals reduce # of locals by coalescing
6164
;; CHECK-NEXT:
6265
;; CHECK-NEXT: --coalesce-locals-learning reduce # of locals by coalescing

0 commit comments

Comments (0)