Skip to content

Commit 5495c36

Browse files
authored
[WebAssembly] Misc. refactoring in AsmTypeCheck (NFC) (#107978)
Existing methods in AsmTypeCheck assumes symbol operand is the 0th operand; they take a `MCInst` and take `getOperand(0)` on it. I think passing a `MCOperand` removes this assumption and also is more intuitive. This was motivated by a new `try_table` instruction, whose support is going to be added to AsmTypeCheck soon, which has tag symbol operands in any position, depending on the number and the kinds of catch clauses. This PR changes all methods' signature that assumes the 0th operand is the relevant one, even if it's not the symbol operand. This also adds `getSignature` method, which factors out the common task when getting a `WasmSignature` from a `MCOperand`.
1 parent 5a2071b commit 5495c36

File tree

2 files changed

+62
-41
lines changed

2 files changed

+62
-41
lines changed

llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
112112
return false;
113113
}
114114

115-
bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
115+
bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
116116
wasm::ValType &Type) {
117-
auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
117+
auto Local = static_cast<size_t>(LocalOp.getImm());
118118
if (Local >= LocalTypes.size())
119119
return typeError(ErrorLoc, StringRef("no local type specified for index ") +
120120
std::to_string(Local));
@@ -178,21 +178,21 @@ bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
178178
return false;
179179
}
180180

181-
bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
181+
bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
182182
const MCSymbolRefExpr *&SymRef) {
183-
auto Op = Inst.getOperand(0);
184-
if (!Op.isExpr())
183+
if (!SymOp.isExpr())
185184
return typeError(ErrorLoc, StringRef("expected expression operand"));
186-
SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
185+
SymRef = dyn_cast<MCSymbolRefExpr>(SymOp.getExpr());
187186
if (!SymRef)
188187
return typeError(ErrorLoc, StringRef("expected symbol operand"));
189188
return false;
190189
}
191190

192-
bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
191+
bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc,
192+
const MCOperand &GlobalOp,
193193
wasm::ValType &Type) {
194194
const MCSymbolRefExpr *SymRef;
195-
if (getSymRef(ErrorLoc, Inst, SymRef))
195+
if (getSymRef(ErrorLoc, GlobalOp, SymRef))
196196
return true;
197197
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
198198
switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
@@ -217,10 +217,10 @@ bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
217217
return false;
218218
}
219219

220-
bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
220+
bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCOperand &TableOp,
221221
wasm::ValType &Type) {
222222
const MCSymbolRefExpr *SymRef;
223-
if (getSymRef(ErrorLoc, Inst, SymRef))
223+
if (getSymRef(ErrorLoc, TableOp, SymRef))
224224
return true;
225225
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
226226
if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
@@ -231,6 +231,34 @@ bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
231231
return false;
232232
}
233233

234+
bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
235+
const MCOperand &SigOp,
236+
wasm::WasmSymbolType Type,
237+
const wasm::WasmSignature *&Sig) {
238+
const MCSymbolRefExpr *SymRef = nullptr;
239+
if (getSymRef(ErrorLoc, SigOp, SymRef))
240+
return true;
241+
const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
242+
Sig = WasmSym->getSignature();
243+
244+
if (!Sig || WasmSym->getType() != Type) {
245+
const char *TypeName = nullptr;
246+
switch (Type) {
247+
case wasm::WASM_SYMBOL_TYPE_FUNCTION:
248+
TypeName = "func";
249+
break;
250+
case wasm::WASM_SYMBOL_TYPE_TAG:
251+
TypeName = "tag";
252+
break;
253+
default:
254+
return true;
255+
}
256+
return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
257+
": missing ." + TypeName + "type");
258+
}
259+
return false;
260+
}
261+
234262
bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
235263
// Check the return types.
236264
for (auto RVT : llvm::reverse(ReturnTypes)) {
@@ -252,56 +280,56 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
252280
dumpTypeStack("typechecking " + Name + ": ");
253281
wasm::ValType Type;
254282
if (Name == "local.get") {
255-
if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
283+
if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
256284
return true;
257285
Stack.push_back(Type);
258286
} else if (Name == "local.set") {
259-
if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
287+
if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
260288
return true;
261289
if (popType(ErrorLoc, Type))
262290
return true;
263291
} else if (Name == "local.tee") {
264-
if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
292+
if (getLocal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
265293
return true;
266294
if (popType(ErrorLoc, Type))
267295
return true;
268296
Stack.push_back(Type);
269297
} else if (Name == "global.get") {
270-
if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
298+
if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
271299
return true;
272300
Stack.push_back(Type);
273301
} else if (Name == "global.set") {
274-
if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
302+
if (getGlobal(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
275303
return true;
276304
if (popType(ErrorLoc, Type))
277305
return true;
278306
} else if (Name == "table.get") {
279-
if (getTable(Operands[1]->getStartLoc(), Inst, Type))
307+
if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
280308
return true;
281309
if (popType(ErrorLoc, wasm::ValType::I32))
282310
return true;
283311
Stack.push_back(Type);
284312
} else if (Name == "table.set") {
285-
if (getTable(Operands[1]->getStartLoc(), Inst, Type))
313+
if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
286314
return true;
287315
if (popType(ErrorLoc, Type))
288316
return true;
289317
if (popType(ErrorLoc, wasm::ValType::I32))
290318
return true;
291319
} else if (Name == "table.size") {
292-
if (getTable(Operands[1]->getStartLoc(), Inst, Type))
320+
if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
293321
return true;
294322
Stack.push_back(wasm::ValType::I32);
295323
} else if (Name == "table.grow") {
296-
if (getTable(Operands[1]->getStartLoc(), Inst, Type))
324+
if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
297325
return true;
298326
if (popType(ErrorLoc, wasm::ValType::I32))
299327
return true;
300328
if (popType(ErrorLoc, Type))
301329
return true;
302330
Stack.push_back(wasm::ValType::I32);
303331
} else if (Name == "table.fill") {
304-
if (getTable(Operands[1]->getStartLoc(), Inst, Type))
332+
if (getTable(Operands[1]->getStartLoc(), Inst.getOperand(0), Type))
305333
return true;
306334
if (popType(ErrorLoc, wasm::ValType::I32))
307335
return true;
@@ -352,15 +380,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
352380
return true;
353381
Unreachable = false;
354382
if (Name == "catch") {
355-
const MCSymbolRefExpr *SymRef;
356-
if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
383+
const wasm::WasmSignature *Sig = nullptr;
384+
if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
385+
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
357386
return true;
358-
const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
359-
const auto *Sig = WasmSym->getSignature();
360-
if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
361-
return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
362-
WasmSym->getName() +
363-
": missing .tagtype");
364387
// catch instruction pushes values whose types are specified in the tag's
365388
// "params" part
366389
Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
@@ -383,15 +406,10 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
383406
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
384407
return true;
385408
} else if (Name == "call" || Name == "return_call") {
386-
const MCSymbolRefExpr *SymRef;
387-
if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
388-
return true;
389-
auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
390-
auto Sig = WasmSym->getSignature();
391-
if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
392-
return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
393-
WasmSym->getName() +
394-
": missing .functype");
409+
const wasm::WasmSignature *Sig = nullptr;
410+
if (getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
411+
wasm::WASM_SYMBOL_TYPE_FUNCTION, Sig))
412+
return true;
395413
if (checkSig(ErrorLoc, *Sig))
396414
return true;
397415
if (Name == "return_call" && endOfFunction(ErrorLoc))

llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,17 @@ class WebAssemblyAsmTypeCheck final {
4141
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
4242
bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
4343
bool popRefType(SMLoc ErrorLoc);
44-
bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
44+
bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
4545
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
4646
bool checkBr(SMLoc ErrorLoc, size_t Level);
4747
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
48-
bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
48+
bool getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
4949
const MCSymbolRefExpr *&SymRef);
50-
bool getGlobal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
51-
bool getTable(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
50+
bool getGlobal(SMLoc ErrorLoc, const MCOperand &GlobalOp,
51+
wasm::ValType &Type);
52+
bool getTable(SMLoc ErrorLoc, const MCOperand &TableOp, wasm::ValType &Type);
53+
bool getSignature(SMLoc ErrorLoc, const MCOperand &SigOp,
54+
wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
5255

5356
public:
5457
WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,

0 commit comments

Comments
 (0)