Skip to content

Commit 1ededd4

Browse files
lforg37Ferdinand LemaireJessica Paquette
committed
[mlir][wasm] Expression parsing mechanism for Wasm importer
--------- Co-authored-by: Ferdinand Lemaire <[email protected]> Co-authored-by: Jessica Paquette <[email protected]>
1 parent 35655b8 commit 1ededd4

File tree

2 files changed

+342
-0
lines changed

2 files changed

+342
-0
lines changed

mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
#include <cstddef>
1717
namespace mlir {
1818
struct WasmBinaryEncoding {
19+
/// Byte encodings for WASM instructions.
struct OpCode {
  // Locals, globals, constants (only the constant opcodes are defined here).
  static constexpr std::byte constI32 = std::byte{0x41};
  static constexpr std::byte constI64 = std::byte{0x42};
  static constexpr std::byte constFP32 = std::byte{0x43};
  static constexpr std::byte constFP64 = std::byte{0x44};
};
27+
1928
/// Byte encodings of types in WASM binaries
2029
struct Type {
2130
static constexpr std::byte emptyBlockType{0x40};

mlir/lib/Target/Wasm/TranslateFromWasm.cpp

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ struct FunctionSymbolRefContainer : SymbolRefContainer {
142142

143143
using ImportDesc = std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
144144

145+
/// Failure, or a list of values. Used as the return type of the instruction
/// parsers and of the operand-popping helpers.
using parsed_inst_t = llvm::FailureOr<llvm::SmallVector<Value>>;
146+
145147
struct WasmModuleSymbolTables {
146148
llvm::SmallVector<FunctionSymbolRefContainer> funcSymbols;
147149
llvm::SmallVector<GlobalSymbolRefContainer> globalSymbols;
@@ -173,6 +175,134 @@ struct WasmModuleSymbolTables {
173175
return getNewSymbolName("table_", id);
174176
}
175177
};
178+
179+
class ParserHead;
180+
181+
/// Wrapper around SmallVector to only allow access as push and pop on the
182+
/// stack. Makes sure that there are no "free accesses" on the stack to preserve
183+
/// its state.
184+
class ValueStack {
185+
private:
186+
struct LabelLevel {
187+
size_t stackIdx;
188+
LabelLevelOpInterface levelOp;
189+
};
190+
public:
191+
bool empty() const { return values.empty(); }
192+
193+
size_t size() const { return values.size(); }
194+
195+
/// Pops values from the stack because they are being used in an operation.
196+
/// @param operandTypes The list of expected types of the operation, used
197+
/// to know how many values to pop and check if the types match the
198+
/// expectation.
199+
/// @param opLoc Location of the caller, used to report accurately the
200+
/// location
201+
/// if an error occurs.
202+
/// @return Failure or the vector of popped values.
203+
llvm::FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes,
204+
Location *opLoc);
205+
206+
/// Push the results of an operation to the stack so they can be used in a
207+
/// following operation.
208+
/// @param results The list of results of the operation
209+
/// @param opLoc Location of the caller, used to report accurately the
210+
/// location
211+
/// if an error occurs.
212+
LogicalResult pushResults(ValueRange results, Location *opLoc);
213+
214+
215+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
216+
/// A simple dump function for debugging.
217+
/// Writes output to llvm::dbgs().
218+
LLVM_DUMP_METHOD void dump() const;
219+
#endif
220+
221+
private:
222+
llvm::SmallVector<Value> values;
223+
};
224+
225+
/// A value statically typed as a `wasmssa::LocalRefType`; element type of the
/// locals list held by ExpressionParser.
using local_val_t = TypedValue<wasmssa::LocalRefType>;
226+
227+
class ExpressionParser {
228+
public:
229+
using locals_t = llvm::SmallVector<local_val_t>;
230+
ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
231+
llvm::ArrayRef<local_val_t> initLocal)
232+
: parser{parser}, symbols{symbols}, locals{initLocal} {}
233+
234+
private:
235+
template <std::byte opCode>
236+
inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
237+
238+
template <typename valueT>
239+
parsed_inst_t
240+
parseConstInst(OpBuilder &builder,
241+
std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
242+
243+
244+
/// This function generates a dispatch tree to associate an opcode with a
245+
/// parser. Parsers are registered by specialising the
246+
/// `parseSpecificInstruction` function for the op code to handle.
247+
///
248+
/// The dispatcher is generated by recursively creating all possible patterns
249+
/// for an opcode and calling the relevant parser on the leaf.
250+
///
251+
/// @tparam patternBitSize is the first bit for which the pattern is not fixed
252+
///
253+
/// @tparam highBitPattern is the fixed pattern that this instance handles for
254+
/// the 8-patternBitSize bits
255+
template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
256+
inline parsed_inst_t dispatchToInstParser(std::byte opCode,
257+
OpBuilder &builder) {
258+
static_assert(patternBitSize <= 8,
259+
"PatternBitSize is outside of range of opcode space! "
260+
"(expected at most 8 bits)");
261+
if constexpr (patternBitSize < 8) {
262+
constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
263+
constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
264+
constexpr size_t nextPatternBitSize = patternBitSize + 1;
265+
if ((opCode & bitSelect) != std::byte{0})
266+
return dispatchToInstParser < nextPatternBitSize,
267+
nextHighBitPatternStem | std::byte{1} > (opCode, builder);
268+
return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
269+
opCode, builder);
270+
} else {
271+
return parseSpecificInstruction<highBitPattern>(builder);
272+
}
273+
}
274+
275+
struct ParseResultWithInfo {
276+
llvm::SmallVector<Value> opResults;
277+
std::byte endingByte;
278+
};
279+
280+
public:
281+
template<std::byte ParseEndByte = WasmBinaryEncoding::endByte>
282+
parsed_inst_t parse(OpBuilder &builder,
283+
UniqueByte<ParseEndByte> = {});
284+
285+
template <std::byte... ExpressionParseEnd>
286+
llvm::FailureOr<ParseResultWithInfo>
287+
parse(OpBuilder &builder,
288+
ByteSequence<ExpressionParseEnd...> parsingEndFilters);
289+
290+
llvm::FailureOr<llvm::SmallVector<Value>>
291+
popOperands(TypeRange operandTypes) {
292+
return valueStack.popOperands(operandTypes, &currentOpLoc.value());
293+
}
294+
295+
LogicalResult pushResults(ValueRange results) {
296+
return valueStack.pushResults(results, &currentOpLoc.value());
297+
}
298+
private:
299+
std::optional<Location> currentOpLoc;
300+
ParserHead &parser;
301+
WasmModuleSymbolTables const &symbols;
302+
locals_t locals;
303+
ValueStack valueStack;
304+
};
305+
176306
class ParserHead {
177307
public:
178308
ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {}
@@ -382,6 +512,14 @@ class ParserHead {
382512
<< static_cast<int>(*importType);
383513
}
384514
}
515+
516+
/// Parses an expression starting at the current position of this head.
/// A fresh ExpressionParser is built over this head, the given symbol tables
/// and the given locals, and run with its default end byte.
/// @return Failure, or the values returned by the expression parser.
parsed_inst_t parseExpression(OpBuilder &builder,
                              WasmModuleSymbolTables const &symbols,
                              llvm::ArrayRef<local_val_t> locals = {}) {
  ExpressionParser exprParser{*this, symbols, locals};
  return exprParser.parse(builder);
}
522+
385523
bool end() const { return curHead().empty(); }
386524

387525
ParserHead copy() const {
@@ -491,6 +629,201 @@ inline llvm::FailureOr<int64_t> ParserHead::parseI64() {
491629
return parseLiteral<int64_t>();
492630
}
493631

632+
template <std::byte opCode>
633+
inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
634+
return emitError(*currentOpLoc, "Unknown instruction opcode: ")
635+
<< static_cast<int>(opCode);
636+
}
637+
638+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
639+
void ValueStack::dump() const {
640+
llvm::dbgs() << "================= Wasm ValueStack =======================\n";
641+
llvm::dbgs() << "size: " << size() << "\n";
642+
llvm::dbgs() << "<Top>"
643+
<< "\n";
644+
// Stack is pushed to via push_back. Therefore the top of the stack is the
645+
// end of the vector. Iterate in reverse so that the first thing we print
646+
// is the top of the stack.
647+
size_t stackSize = size();
648+
for (size_t idx = 0 ; idx < stackSize ;) {
649+
size_t actualIdx = stackSize - 1 - idx;
650+
llvm::dbgs() << " ";
651+
values[actualIdx].dump();
652+
}
653+
llvm::dbgs() << "<Bottom>"
654+
<< "\n";
655+
llvm::dbgs() << "=========================================================\n";
656+
}
657+
#endif
658+
659+
/// Removes the `operandTypes.size()` topmost values from the stack and
/// returns them in the order they were pushed. Each value must have the type
/// found at the same position in `operandTypes`. If the stack is too small or
/// a type does not match, an error is emitted at `*opLoc` and the stack is
/// left unchanged.
parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
  const size_t numOperands = operandTypes.size();
  LLVM_DEBUG(llvm::dbgs() << "Popping from ValueStack\n");
  LLVM_DEBUG(llvm::dbgs() << " Elements(s) to pop: " << numOperands << "\n");
  LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
  if (values.size() < numOperands)
    return emitError(*opLoc,
                     "Stack doesn't contain enough values. Trying to get ")
           << numOperands << " operands on a stack containing only "
           << values.size() << " values.";

  // Index of the deepest value that is going to be popped.
  const size_t firstOperandIdx = values.size() - numOperands;
  llvm::SmallVector<Value> popped{};
  popped.reserve(numOperands);
  for (size_t operandIdx{0}; operandIdx < numOperands; ++operandIdx) {
    Value operand = values[firstOperandIdx + operandIdx];
    Type stackType = operand.getType();
    if (stackType != operandTypes[operandIdx])
      return emitError(*opLoc, "Invalid operand type on stack. Expecting ")
             << operandTypes[operandIdx] << ", value on stack is of type "
             << stackType << ".";
    LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n");
    popped.push_back(operand);
  }
  // Only shrink the stack once every operand has been validated.
  values.resize(firstOperandIdx);
  LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n");
  return popped;
}
687+
688+
/// Appends every value of `results`, in order, on top of the stack. A value
/// whose type is rejected by `isWasmValueType` stops the operation with an
/// error emitted at `*opLoc`; values pushed before it stay on the stack.
LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
  LLVM_DEBUG(llvm::dbgs() << "Pushing to ValueStack\n");
  LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size()
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n");
  for (Value pushed : results) {
    Type pushedType = pushed.getType();
    if (!isWasmValueType(pushedType))
      return emitError(*opLoc, "Invalid value type on stack: ") << pushedType;
    LLVM_DEBUG(llvm::dbgs() << " PUSH: " << pushed << "\n");
    values.push_back(pushed);
  }

  LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n");
  return success();
}
704+
705+
template<std::byte EndParseByte>
706+
parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte<EndParseByte> endByte) {
707+
auto res = parse(builder, ByteSequence<EndParseByte>{});
708+
if (failed(res))
709+
return failure();
710+
return res->opResults;
711+
}
712+
713+
template <std::byte... ExpressionParseEnd>
714+
llvm::FailureOr<ExpressionParser::ParseResultWithInfo>
715+
ExpressionParser::parse(OpBuilder &builder,
716+
ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
717+
llvm::SmallVector<Value> res;
718+
for (;;) {
719+
currentOpLoc = parser.getLocation();
720+
auto opCode = parser.consumeByte();
721+
if (failed(opCode))
722+
return failure();
723+
if (isValueOneOf(*opCode, parsingEndFilters))
724+
return {{res, *opCode}};
725+
parsed_inst_t resParsed;
726+
resParsed = dispatchToInstParser(*opCode, builder);
727+
if (failed(resParsed))
728+
return failure();
729+
std::swap(res, *resParsed);
730+
if (failed(pushResults(res)))
731+
return failure();
732+
}
733+
}
734+
735+
736+
template <typename T>
737+
inline Type buildLiteralType(OpBuilder &);
738+
739+
template <>
740+
inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
741+
return builder.getI32Type();
742+
}
743+
744+
template <>
745+
inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
746+
return builder.getI64Type();
747+
}
748+
749+
template <>
750+
inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
751+
return builder.getI32Type();
752+
}
753+
754+
template <>
755+
inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
756+
return builder.getI64Type();
757+
}
758+
759+
template <>
760+
inline Type buildLiteralType<float>(OpBuilder &builder) {
761+
return builder.getF32Type();
762+
}
763+
764+
template <>
765+
inline Type buildLiteralType<double>(OpBuilder &builder) {
766+
return builder.getF64Type();
767+
}
768+
769+
template<typename ValT, typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
770+
struct AttrHolder;
771+
772+
template <typename ValT>
773+
struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
774+
using type = IntegerAttr;
775+
};
776+
777+
template <typename ValT>
778+
struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
779+
using type = FloatAttr;
780+
};
781+
782+
template<typename ValT>
783+
using attr_holder_t = typename AttrHolder<ValT>::type;
784+
785+
template <typename ValT,
786+
typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
787+
attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
788+
return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
789+
}
790+
791+
template <typename valueT>
792+
parsed_inst_t ExpressionParser::parseConstInst(
793+
OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
794+
auto parsedConstant = parser.parseLiteral<valueT>();
795+
if (failed(parsedConstant))
796+
return failure();
797+
auto constOp = builder.create<ConstOp>(
798+
*currentOpLoc, buildLiteralAttr<valueT>(builder, *parsedConstant));
799+
return {{constOp.getResult()}};
800+
}
801+
802+
template <>
803+
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
804+
WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
805+
return parseConstInst<int32_t>(builder);
806+
}
807+
808+
template <>
809+
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
810+
WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
811+
return parseConstInst<int64_t>(builder);
812+
}
813+
814+
template <>
815+
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
816+
WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
817+
return parseConstInst<float>(builder);
818+
}
819+
820+
template <>
821+
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
822+
WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
823+
return parseConstInst<double>(builder);
824+
}
825+
826+
494827
class WasmBinaryParser {
495828
private:
496829
struct SectionRegistry {

0 commit comments

Comments
 (0)