Skip to content

Commit 1c6b0a2

Browse files
committed
refactor Set => Expr, as we will need more expressions, not just set values
1 parent 7a63296 commit 1c6b0a2

File tree

1 file changed

+87
-90
lines changed

1 file changed

+87
-90
lines changed

src/passes/Souperify.cpp

Lines changed: 87 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ namespace wasm {
4040
namespace DataFlow {
4141

4242
struct Node {
43-
// A node is either a Set, which represents a wasm IR operation, or something
44-
// new in the Souper IR.
43+
// We reuse the Binaryen IR as much as possible: when things are identical between
44+
// the two IRs, we just create an Expr node and refer to the Binaryen Expression.
45+
// Other node types here are special things from Souper IR that we can't
46+
// represent that way.
4547
// TODO: add more nodes for the differences between the two IRs, like i1.
4648
enum Type {
4749
Var, // an unknown variable number (not to be confused with var/param/local in wasm)
48-
Set, // a register, defined by a SetLocal
50+
Expr, // a value represented by a Binaryen Expression
4951
Const, // a constant value
5052
Phi, // a phi from converging control flow
5153
Cond, // a condition on a block path (pc or blockpc)
@@ -55,14 +57,14 @@ struct Node {
5557

5658
Node(Type type) : type(type) {}
5759

58-
template<Type expected>
59-
bool is() { return type == expected; }
60+
// TODO: the others, if we need them
61+
bool isBad() { return type == Bad; }
6062

6163
union {
6264
// For Var
6365
Index varIndex;
64-
// For Set
65-
SetLocal* set;
66+
// For Expr
67+
Expression* expr;
6668
// For Const
6769
Literal value;
6870
// For Phi
@@ -91,9 +93,9 @@ struct Node {
9193
ret->varIndex = varIndex;
9294
return ret;
9395
}
94-
static Node* makeSet(SetLocal* set) {
95-
Node* ret = new Node(Set);
96-
ret->set = set;
96+
static Node* makeExpr(Expression* expr) {
97+
Node* ret = new Node(Expr);
98+
ret->expr = expr;
9799
return ret;
98100
}
99101
static Node* makeConst(Literal value) {
@@ -136,10 +138,16 @@ struct Node {
136138
}
137139
};
138140

141+
// We only need one canonical bad node. It is never modified.
142+
static Node CanonicalBad(Node::Type::Bad);
143+
139144
// Main logic to generate IR for a function. This is implemented as a
140-
// visitor on the wasm, where visitors return a bool whether the node is
141-
// supported (false === bad).
142-
struct Builder : public Visitor<Builder, bool> {
145+
// visitor on the wasm, where visitors return a Node* that either
146+
// contains the DataFlow IR for that expression, which can be a
147+
// Bad node if not supported, or nullptr if not relevant (we only
148+
// use the return value for internal expressions, that is, the
149+
// value of a set_local or the condition of an if etc).
150+
struct Builder : public Visitor<Builder, Node*> {
143151
// Tracks the state of locals in a control flow path:
144152
// localState[i] = the node whose value it contains
145153
typedef std::vector<Node*> LocalState;
@@ -194,7 +202,8 @@ struct Builder : public Visitor<Builder, bool> {
194202
} else {
195203
node = Node::makeConst(LiteralUtils::makeLiteralZero(func->getLocalType(i)));
196204
}
197-
addNode(node, i);
205+
addNode(node);
206+
localState[i] = node;
198207
}
199208
// Process the function body, generating the rest of the IR.
200209
visit(func->body);
@@ -207,12 +216,6 @@ struct Builder : public Visitor<Builder, bool> {
207216
return node;
208217
}
209218

210-
// Add a new node, for a specific local index.
211-
Node* addNode(Node* node, Index index) {
212-
localState[index] = node;
213-
return addNode(node);
214-
}
215-
216219
// Merge local state for multiple control flow paths
217220
// TODO: more than 2
218221
void merge(const LocalState& aState, const LocalState& bState, Node* condition, Expression* expr, LocalState& out) {
@@ -239,7 +242,7 @@ struct Builder : public Visitor<Builder, bool> {
239242

240243
// Visitors.
241244

242-
bool visitBlock(Block* curr) {
245+
Node* visitBlock(Block* curr) {
243246
// TODO: handle super-deep nesting
244247
// TODO: handle breaks to here
245248
auto* oldParent = parent;
@@ -249,22 +252,14 @@ struct Builder : public Visitor<Builder, bool> {
249252
visit(child);
250253
}
251254
parent = oldParent;
252-
return true;
255+
return nullptr;
253256
}
254-
bool visitIf(If* curr) {
257+
Node* visitIf(If* curr) {
255258
auto* oldParent = parent;
256259
parentMap[curr] = oldParent;
257260
parent = curr;
258261
// Set up the condition.
259-
// TODO: move this const-or-get logic to a helper, we'll need it elsewhere I am quite sure
260-
Node* condition;
261-
if (auto* get = curr->condition->dynCast<GetLocal>()) {
262-
visit(curr->condition);
263-
condition = getNodeMap[get];
264-
} else {
265-
auto* c = curr->condition->cast<wasm::Const>();
266-
condition = addNode(Node::makeConst(c->value));
267-
}
262+
Node* condition = visit(curr->condition);
268263
assert(condition);
269264
// Handle the contents.
270265
auto initialState = localState;
@@ -279,68 +274,64 @@ struct Builder : public Visitor<Builder, bool> {
279274
merge(initialState, afterIfTrueState, condition, curr, localState);
280275
}
281276
parent = oldParent;
282-
return true;
283-
}
284-
bool visitLoop(Loop* curr) { return false; }
285-
bool visitBreak(Break* curr) { return false; }
286-
bool visitSwitch(Switch* curr) { return false; }
287-
bool visitCall(Call* curr) { return false; }
288-
bool visitCallImport(CallImport* curr) { return false; }
289-
bool visitCallIndirect(CallIndirect* curr) { return false; }
290-
bool visitGetLocal(GetLocal* curr) {
277+
return nullptr;
278+
}
279+
Node* visitLoop(Loop* curr) { return &CanonicalBad; }
280+
Node* visitBreak(Break* curr) { return &CanonicalBad; }
281+
Node* visitSwitch(Switch* curr) { return &CanonicalBad; }
282+
Node* visitCall(Call* curr) { return &CanonicalBad; }
283+
Node* visitCallImport(CallImport* curr) { return &CanonicalBad; }
284+
Node* visitCallIndirect(CallIndirect* curr) { return &CanonicalBad; }
285+
Node* visitGetLocal(GetLocal* curr) {
291286
// We now know which IR node this get refers to
292287
auto* node = localState[curr->index];
293288
getNodeMap[curr] = node;
294-
return !node->is<Node::Type::Bad>();
289+
return node;
295290
}
296-
bool visitSetLocal(SetLocal* curr) {
291+
Node* visitSetLocal(SetLocal* curr) {
297292
sets.push_back(curr);
298293
parentMap[curr] = parent;
299294
// If we are doing a copy, just do the copy.
300295
if (auto* get = curr->value->dynCast<GetLocal>()) {
301296
setNodeMap[curr] = localState[curr->index] = localState[get->index];
302-
return true;
297+
return nullptr;
303298
}
304299
// Make a new IR node for the new value here.
305-
if (visit(curr->value)) {
306-
setNodeMap[curr] = addNode(Node::makeSet(curr), curr->index);
307-
return true;
308-
} else {
309-
setNodeMap[curr] = addNode(Node::makeBad(), curr->index);
310-
return false;
311-
}
300+
auto* node = setNodeMap[curr] = visit(curr->value);
301+
localState[curr->index] = node;
302+
return nullptr;
312303
}
313-
bool visitGetGlobal(GetGlobal* curr) {
314-
return false;
304+
Node* visitGetGlobal(GetGlobal* curr) {
305+
return &CanonicalBad;
315306
}
316-
bool visitSetGlobal(SetGlobal* curr) {
317-
return false;
307+
Node* visitSetGlobal(SetGlobal* curr) {
308+
return &CanonicalBad;
318309
}
319-
bool visitLoad(Load* curr) {
320-
return false;
310+
Node* visitLoad(Load* curr) {
311+
return &CanonicalBad;
321312
}
322-
bool visitStore(Store* curr) {
323-
return false;
313+
Node* visitStore(Store* curr) {
314+
return &CanonicalBad;
324315
}
325-
bool visitAtomicRMW(AtomicRMW* curr) {
326-
return false;
316+
Node* visitAtomicRMW(AtomicRMW* curr) {
317+
return &CanonicalBad;
327318
}
328-
bool visitAtomicCmpxchg(AtomicCmpxchg* curr) {
329-
return false;
319+
Node* visitAtomicCmpxchg(AtomicCmpxchg* curr) {
320+
return &CanonicalBad;
330321
}
331-
bool visitAtomicWait(AtomicWait* curr) {
332-
return false;
322+
Node* visitAtomicWait(AtomicWait* curr) {
323+
return &CanonicalBad;
333324
}
334-
bool visitAtomicWake(AtomicWake* curr) {
335-
return false;
325+
Node* visitAtomicWake(AtomicWake* curr) {
326+
return &CanonicalBad;
336327
}
337-
bool visitConst(Const* curr) {
338-
return true;
328+
Node* visitConst(Const* curr) {
329+
return addNode(Node::makeConst(curr->value));
339330
}
340-
bool visitUnary(Unary* curr) {
341-
return false;
331+
Node* visitUnary(Unary* curr) {
332+
return &CanonicalBad;
342333
}
343-
bool visitBinary(Binary *curr) {
334+
Node* visitBinary(Binary *curr) {
344335
// First, check if we support this op.
345336
switch (curr->op) {
346337
case AddInt32:
@@ -394,29 +385,35 @@ struct Builder : public Visitor<Builder, bool> {
394385
case GeUInt32:
395386
case GeUInt64: break; // these are ok
396387

397-
default: return false; // anything else is bad
388+
default: return &CanonicalBad; // anything else is bad
398389
}
399390
// Then, check if our children are supported.
400-
return visit(curr->left) && visit(curr->right);
391+
// XXX we drop the return values here. For a Const, we created a node
392+
// that we no longer need.
393+
if (!visit(curr->left)->isBad() && !visit(curr->right)->isBad()) {
394+
return addNode(Node::makeExpr(curr));
395+
} else {
396+
return &CanonicalBad;
397+
}
401398
}
402-
bool visitSelect(Select* curr) {
403-
return false;
399+
Node* visitSelect(Select* curr) {
400+
return &CanonicalBad;
404401
}
405-
bool visitDrop(Drop* curr) {
406-
return false;
402+
Node* visitDrop(Drop* curr) {
403+
return &CanonicalBad;
407404
}
408-
bool visitReturn(Return* curr) {
405+
Node* visitReturn(Return* curr) {
409406
// note we don't need the value (it's a const or a get as we are flattened)
410-
return true;
407+
return nullptr;
411408
}
412-
bool visitHost(Host* curr) {
413-
return false;
409+
Node* visitHost(Host* curr) {
410+
return &CanonicalBad;
414411
}
415-
bool visitNop(Nop* curr) {
416-
return true;
412+
Node* visitNop(Nop* curr) {
413+
return nullptr;
417414
}
418-
bool visitUnreachable(Unreachable* curr) {
419-
return false;
415+
Node* visitUnreachable(Unreachable* curr) {
416+
return &CanonicalBad;
420417
}
421418
};
422419

@@ -453,9 +450,9 @@ struct Trace : public Visitor<Trace> {
453450
case Node::Type::Var: {
454451
break; // nothing more to add
455452
}
456-
case Node::Type::Set: {
453+
case Node::Type::Expr: {
457454
// Add the dependencies.
458-
visit(node->set->value);
455+
visit(node->expr);
459456
break;
460457
}
461458
case Node::Type::Const: {
@@ -584,9 +581,9 @@ struct Printer : public Visitor<Printer> {
584581
std::cout << "%" << indexing[node] << ":" << printType(builder.func->getLocalType(node->varIndex)) << " = var";
585582
break; // nothing more to add
586583
}
587-
case Node::Type::Set: {
584+
case Node::Type::Expr: {
588585
std::cout << "%" << indexing[node] << " = ";
589-
visit(node->set->value);
586+
visit(node->expr);
590587
break;
591588
}
592589
case Node::Type::Const: {

0 commit comments

Comments
 (0)