diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h index 1ab80fb27fa9a..622024db5a8a2 100644 --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -20,6 +20,7 @@ namespace mlir { void registerFromLLVMIRTranslation(); void registerFromSPIRVTranslation(); +void registerFromWasmTranslation(); void registerToCppTranslation(); void registerToLLVMIRTranslation(); void registerToSPIRVTranslation(); @@ -36,6 +37,7 @@ inline void registerAllTranslations() { registerFromLLVMIRTranslation(); registerFromSPIRVTranslation(); registerIRDLToCppTranslation(); + registerFromWasmTranslation(); registerToCppTranslation(); registerToLLVMIRTranslation(); registerToSPIRVTranslation(); diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h new file mode 100644 index 0000000000000..3280432b5f038 --- /dev/null +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -0,0 +1,74 @@ +//===- WasmBinaryEncoding.h - Byte encodings for Wasm binary format ===----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Define various flags used to encode instructions, types, etc. in +// WebAssembly binary format. +// +// These encodings are defined in the WebAssembly binary format specification. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TARGET_WASMBINARYENCODING +#define MLIR_TARGET_WASMBINARYENCODING + +#include + +namespace mlir { +struct WasmBinaryEncoding { + /// Byte encodings for Wasm instructions. + struct OpCode { + // Locals, globals, constants. + static constexpr std::byte constI32{0x41}; + static constexpr std::byte constI64{0x42}; + static constexpr std::byte constFP32{0x43}; + static constexpr std::byte constFP64{0x44}; + }; + + /// Byte encodings of types in Wasm binaries + struct Type { + static constexpr std::byte emptyBlockType{0x40}; + static constexpr std::byte funcType{0x60}; + static constexpr std::byte externRef{0x6F}; + static constexpr std::byte funcRef{0x70}; + static constexpr std::byte v128{0x7B}; + static constexpr std::byte f64{0x7C}; + static constexpr std::byte f32{0x7D}; + static constexpr std::byte i64{0x7E}; + static constexpr std::byte i32{0x7F}; + }; + + /// Byte encodings of Wasm imports. + struct Import { + static constexpr std::byte typeID{0x00}; + static constexpr std::byte tableType{0x01}; + static constexpr std::byte memType{0x02}; + static constexpr std::byte globalType{0x03}; + }; + + /// Byte encodings for Wasm limits. + struct LimitHeader { + static constexpr std::byte lowLimitOnly{0x00}; + static constexpr std::byte bothLimits{0x01}; + }; + + /// Byte encodings describing the mutability of globals. + struct GlobalMutability { + static constexpr std::byte isConst{0x00}; + static constexpr std::byte isMutable{0x01}; + }; + + /// Byte encodings describing Wasm exports. + struct Export { + static constexpr std::byte function{0x00}; + static constexpr std::byte table{0x01}; + static constexpr std::byte memory{0x02}; + static constexpr std::byte global{0x03}; + }; + + static constexpr std::byte endByte{0x0B}; +}; +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h new file mode 100644 index 0000000000000..3f05bbecefc8a --- /dev/null +++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h @@ -0,0 +1,31 @@ +//===- WasmImporter.h - Helpers to create WebAssembly emitter ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines helpers to import WebAssembly code using the WebAssembly +// dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_WASM_WASMIMPORTER_H +#define MLIR_TARGET_WASM_WASMIMPORTER_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir::wasm { + +/// If `source` contains a valid Wasm binary file, this function returns a +/// a ModuleOp containing the representation of the Wasm module encoded in +/// the source file in the `wasmssa` dialect. +OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, + MLIRContext *context); +} // namespace mlir::wasm + +#endif // MLIR_TARGET_WASM_WASMIMPORTER_H diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index 6eb0abc214d38..f0c3ac4d511c1 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(SPIRV) add_subdirectory(LLVMIR) add_subdirectory(LLVM) add_subdirectory(SMTLIB) +add_subdirectory(Wasm) diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt new file mode 100644 index 0000000000000..890fc0ecfbeb6 --- /dev/null +++ b/mlir/lib/Target/Wasm/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(MLIRTargetWasmImport + TranslateRegistration.cpp + TranslateFromWasm.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm + + LINK_LIBS PUBLIC + MLIRWasmSSADialect + MLIRIR + MLIRSupport + MLIRTranslateLib +) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp new file mode 100644 index 0000000000000..c23a2915ef18e --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -0,0 +1,1246 @@ +//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the WebAssembly importer. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/Target/Wasm/WasmBinaryEncoding.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LEB128.h" + +#include +#include +#include + +#define DEBUG_TYPE "wasm-translate" + +// Statistics. +STATISTIC(numFunctionSectionItems, "Parsed functions"); +STATISTIC(numGlobalSectionItems, "Parsed globals"); +STATISTIC(numMemorySectionItems, "Parsed memories"); +STATISTIC(numTableSectionItems, "Parsed tables"); + +static_assert(CHAR_BIT == 8, + "This code expects std::byte to be exactly 8 bits"); + +using namespace mlir; +using namespace mlir::wasm; +using namespace mlir::wasmssa; + +namespace { +using section_id_t = uint8_t; +enum struct WasmSectionType : section_id_t { + CUSTOM = 0, + TYPE = 1, + IMPORT = 2, + FUNCTION = 3, + TABLE = 4, + MEMORY = 5, + GLOBAL = 6, + EXPORT = 7, + START = 8, + ELEMENT = 9, + CODE = 10, + DATA = 11, + DATACOUNT = 12 +}; + +constexpr section_id_t highestWasmSectionID{ + static_cast(WasmSectionType::DATACOUNT)}; + +#define APPLY_WASM_SEC_TRANSFORM \ + WASM_SEC_TRANSFORM(CUSTOM) \ + WASM_SEC_TRANSFORM(TYPE) \ + WASM_SEC_TRANSFORM(IMPORT) \ + WASM_SEC_TRANSFORM(FUNCTION) \ + WASM_SEC_TRANSFORM(TABLE) \ + WASM_SEC_TRANSFORM(MEMORY) \ + WASM_SEC_TRANSFORM(GLOBAL) \ + WASM_SEC_TRANSFORM(EXPORT) \ + WASM_SEC_TRANSFORM(START) \ + WASM_SEC_TRANSFORM(ELEMENT) \ + WASM_SEC_TRANSFORM(CODE) \ + WASM_SEC_TRANSFORM(DATA) \ + WASM_SEC_TRANSFORM(DATACOUNT) + +template +constexpr const char *wasmSectionName = ""; + +#define WASM_SEC_TRANSFORM(section) \ + template <> \ + constexpr const char *wasmSectionName = #section; +APPLY_WASM_SEC_TRANSFORM +#undef WASM_SEC_TRANSFORM + +constexpr bool sectionShouldBeUnique(WasmSectionType secType) { + return secType != WasmSectionType::CUSTOM; +} + +template +struct ByteSequence {}; + +/// Template class for representing a byte sequence of only one byte +template +struct UniqueByte : ByteSequence {}; + +constexpr ByteSequence< + WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64, + WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64, + WasmBinaryEncoding::Type::v128> + valueTypesEncodings{}; + +template +constexpr bool isValueOneOf(std::byte value, + ByteSequence = {}) { + return ((value == allowedFlags) | ... | false); +} + +template +constexpr bool isNotIn(std::byte value, ByteSequence = {}) { + return !isValueOneOf(value); +} + +struct GlobalTypeRecord { + Type type; + bool isMutable; +}; + +struct TypeIdxRecord { + size_t id; +}; + +struct SymbolRefContainer { + FlatSymbolRefAttr symbol; +}; + +struct GlobalSymbolRefContainer : SymbolRefContainer { + Type globalType; +}; + +struct FunctionSymbolRefContainer : SymbolRefContainer { + FunctionType functionType; +}; + +using ImportDesc = + std::variant; + +using parsed_inst_t = FailureOr>; + +struct WasmModuleSymbolTables { + SmallVector funcSymbols; + SmallVector globalSymbols; + SmallVector memSymbols; + SmallVector tableSymbols; + SmallVector moduleFuncTypes; + + std::string getNewSymbolName(StringRef prefix, size_t id) const { + return (prefix + Twine{id}).str(); + } + + std::string getNewFuncSymbolName() const { + auto id = funcSymbols.size(); + return getNewSymbolName("func_", id); + } + + std::string getNewGlobalSymbolName() const { + auto id = globalSymbols.size(); + return getNewSymbolName("global_", id); + } + + std::string getNewMemorySymbolName() const { + auto id = memSymbols.size(); + return getNewSymbolName("mem_", id); + } + + std::string getNewTableSymbolName() const { + auto id = tableSymbols.size(); + return getNewSymbolName("table_", id); + } +}; + +class ParserHead; + +/// Wrapper around SmallVector to only allow access as push and pop on the +/// stack. Makes sure that there are no "free accesses" on the stack to preserve +/// its state. +class ValueStack { +private: + struct LabelLevel { + size_t stackIdx; + LabelLevelOpInterface levelOp; + }; + +public: + bool empty() const { return values.empty(); } + + size_t size() const { return values.size(); } + + /// Pops values from the stack because they are being used in an operation. + /// @param operandTypes The list of expected types of the operation, used + /// to know how many values to pop and check if the types match the + /// expectation. + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + /// @return Failure or the vector of popped values. + FailureOr> popOperands(TypeRange operandTypes, + Location *opLoc); + + /// Push the results of an operation to the stack so they can be used in a + /// following operation. + /// @param results The list of results of the operation + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + LogicalResult pushResults(ValueRange results, Location *opLoc); + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// A simple dump function for debugging. + /// Writes output to llvm::dbgs(). + LLVM_DUMP_METHOD void dump() const; +#endif + +private: + SmallVector values; +}; + +using local_val_t = TypedValue; + +class ExpressionParser { +public: + using locals_t = SmallVector; + ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols, + ArrayRef initLocal) + : parser{parser}, symbols{symbols}, locals{initLocal} {} + +private: + template + inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder); + + template + parsed_inst_t + parseConstInst(OpBuilder &builder, + std::enable_if_t> * = nullptr); + + /// This function generates a dispatch tree to associate an opcode with a + /// parser. Parsers are registered by specialising the + /// `parseSpecificInstruction` function for the op code to handle. + /// + /// The dispatcher is generated by recursively creating all possible patterns + /// for an opcode and calling the relevant parser on the leaf. + /// + /// @tparam patternBitSize is the first bit for which the pattern is not fixed + /// + /// @tparam highBitPattern is the fixed pattern that this instance handles for + /// the 8-patternBitSize bits + template + inline parsed_inst_t dispatchToInstParser(std::byte opCode, + OpBuilder &builder) { + static_assert(patternBitSize <= 8, + "PatternBitSize is outside of range of opcode space! " + "(expected at most 8 bits)"); + if constexpr (patternBitSize < 8) { + constexpr std::byte bitSelect{1 << (7 - patternBitSize)}; + constexpr std::byte nextHighBitPatternStem = highBitPattern << 1; + constexpr size_t nextPatternBitSize = patternBitSize + 1; + if ((opCode & bitSelect) != std::byte{0}) + return dispatchToInstParser( + opCode, builder); + return dispatchToInstParser( + opCode, builder); + } else { + return parseSpecificInstruction(builder); + } + } + + struct ParseResultWithInfo { + SmallVector opResults; + std::byte endingByte; + }; + +public: + template + parsed_inst_t parse(OpBuilder &builder, UniqueByte = {}); + + template + FailureOr + parse(OpBuilder &builder, + ByteSequence parsingEndFilters); + + FailureOr> popOperands(TypeRange operandTypes) { + return valueStack.popOperands(operandTypes, ¤tOpLoc.value()); + } + + LogicalResult pushResults(ValueRange results) { + return valueStack.pushResults(results, ¤tOpLoc.value()); + } + +private: + std::optional currentOpLoc; + ParserHead &parser; + WasmModuleSymbolTables const &symbols; + locals_t locals; + ValueStack valueStack; +}; + +class ParserHead { +public: + ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {} + ParserHead(ParserHead &&) = default; + +private: + ParserHead(ParserHead const &other) = default; + +public: + auto getLocation() const { + return FileLineColLoc::get(locName, 0, anchorOffset + offset); + } + + FailureOr consumeNBytes(size_t nBytes) { + LDBG() << "Consume " << nBytes << " bytes"; + LDBG() << " Bytes remaining: " << size(); + LDBG() << " Current offset: " << offset; + if (nBytes > size()) + return emitError(getLocation(), "trying to extract ") + << nBytes << "bytes when only " << size() << "are available"; + + StringRef res = head.slice(offset, offset + nBytes); + offset += nBytes; + LDBG() << " Updated offset (+" << nBytes << "): " << offset; + return res; + } + + FailureOr consumeByte() { + auto res = consumeNBytes(1); + if (failed(res)) + return failure(); + return std::byte{*res->bytes_begin()}; + } + + template + FailureOr parseLiteral(); + + FailureOr parseVectorSize(); + +private: + // TODO: This is equivalent to parseLiteral and could be removed + // if parseLiteral specialization were moved here, but default GCC on Ubuntu + // 22.04 has bug with template specialization in class declaration + inline FailureOr parseUI32(); + inline FailureOr parseI64(); + +public: + FailureOr parseName() { + FailureOr size = parseVectorSize(); + if (failed(size)) + return failure(); + + return consumeNBytes(*size); + } + + FailureOr parseWasmSectionType() { + FailureOr id = consumeByte(); + if (failed(id)) + return failure(); + if (std::to_integer(*id) > highestWasmSectionID) + return emitError(getLocation(), "invalid section ID: ") + << static_cast(*id); + return static_cast(*id); + } + + FailureOr parseLimit(MLIRContext *ctx) { + using WasmLimits = WasmBinaryEncoding::LimitHeader; + FileLineColLoc limitLocation = getLocation(); + FailureOr limitHeader = consumeByte(); + if (failed(limitHeader)) + return failure(); + + if (isNotIn(*limitHeader)) + return emitError(limitLocation, "invalid limit header: ") + << static_cast(*limitHeader); + FailureOr minParse = parseUI32(); + if (failed(minParse)) + return failure(); + std::optional max{std::nullopt}; + if (*limitHeader == WasmLimits::bothLimits) { + FailureOr maxParse = parseUI32(); + if (failed(maxParse)) + return failure(); + max = *maxParse; + } + return LimitType::get(ctx, *minParse, max); + } + + FailureOr parseValueType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr typeEncoding = consumeByte(); + if (failed(typeEncoding)) + return failure(); + switch (*typeEncoding) { + case WasmBinaryEncoding::Type::i32: + return IntegerType::get(ctx, 32); + case WasmBinaryEncoding::Type::i64: + return IntegerType::get(ctx, 64); + case WasmBinaryEncoding::Type::f32: + return Float32Type::get(ctx); + case WasmBinaryEncoding::Type::f64: + return Float64Type::get(ctx); + case WasmBinaryEncoding::Type::v128: + return IntegerType::get(ctx, 128); + case WasmBinaryEncoding::Type::funcRef: + return wasmssa::FuncRefType::get(ctx); + case WasmBinaryEncoding::Type::externRef: + return wasmssa::ExternRefType::get(ctx); + default: + return emitError(typeLoc, "invalid value type encoding: ") + << static_cast(*typeEncoding); + } + } + + FailureOr parseGlobalType(MLIRContext *ctx) { + using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability; + FailureOr typeParsed = parseValueType(ctx); + if (failed(typeParsed)) + return failure(); + FileLineColLoc mutLoc = getLocation(); + FailureOr mutSpec = consumeByte(); + if (failed(mutSpec)) + return failure(); + if (isNotIn(*mutSpec)) + return emitError(mutLoc, "invalid global mutability specifier: ") + << static_cast(*mutSpec); + return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable}; + } + + FailureOr parseResultType(MLIRContext *ctx) { + FailureOr nParamsParsed = parseVectorSize(); + if (failed(nParamsParsed)) + return failure(); + uint32_t nParams = *nParamsParsed; + SmallVector res{}; + res.reserve(nParams); + for (size_t i = 0; i < nParams; ++i) { + FailureOr parsedType = parseValueType(ctx); + if (failed(parsedType)) + return failure(); + res.push_back(*parsedType); + } + return TupleType::get(ctx, res); + } + + FailureOr parseFunctionType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr funcTypeHeader = consumeByte(); + if (failed(funcTypeHeader)) + return failure(); + if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType) + return emitError(typeLoc, "invalid function type header byte. Expecting ") + << std::to_integer(WasmBinaryEncoding::Type::funcType) + << " got " << std::to_integer(*funcTypeHeader); + FailureOr inputTypes = parseResultType(ctx); + if (failed(inputTypes)) + return failure(); + + FailureOr resTypes = parseResultType(ctx); + if (failed(resTypes)) + return failure(); + + return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes()); + } + + FailureOr parseTypeIndex() { + FailureOr res = parseUI32(); + if (failed(res)) + return failure(); + return TypeIdxRecord{*res}; + } + + FailureOr parseTableType(MLIRContext *ctx) { + FailureOr elmTypeParse = parseValueType(ctx); + if (failed(elmTypeParse)) + return failure(); + if (!isWasmRefType(*elmTypeParse)) + return emitError(getLocation(), "invalid element type for table"); + FailureOr limitParse = parseLimit(ctx); + if (failed(limitParse)) + return failure(); + return TableType::get(ctx, *elmTypeParse, *limitParse); + } + + FailureOr parseImportDesc(MLIRContext *ctx) { + FileLineColLoc importLoc = getLocation(); + FailureOr importType = consumeByte(); + auto packager = [](auto parseResult) -> FailureOr { + if (llvm::failed(parseResult)) + return failure(); + return {*parseResult}; + }; + if (failed(importType)) + return failure(); + switch (*importType) { + case WasmBinaryEncoding::Import::typeID: + return packager(parseTypeIndex()); + case WasmBinaryEncoding::Import::tableType: + return packager(parseTableType(ctx)); + case WasmBinaryEncoding::Import::memType: + return packager(parseLimit(ctx)); + case WasmBinaryEncoding::Import::globalType: + return packager(parseGlobalType(ctx)); + default: + return emitError(importLoc, "invalid import type descriptor: ") + << static_cast(*importType); + } + } + + parsed_inst_t parseExpression(OpBuilder &builder, + WasmModuleSymbolTables const &symbols, + ArrayRef locals = {}) { + auto eParser = ExpressionParser{*this, symbols, locals}; + return eParser.parse(builder); + } + + bool end() const { return curHead().empty(); } + + ParserHead copy() const { return *this; } + +private: + StringRef curHead() const { return head.drop_front(offset); } + + FailureOr peek() const { + if (end()) + return emitError( + getLocation(), + "trying to peek at next byte, but input stream is empty"); + return static_cast(curHead().front()); + } + + size_t size() const { return head.size() - offset; } + + StringRef head; + StringAttr locName; + unsigned anchorOffset{0}; + unsigned offset{0}; +}; + +template <> +FailureOr ParserHead::parseLiteral() { + auto bytes = consumeNBytes(4); + if (failed(bytes)) + return failure(); + float result; + std::memcpy(&result, bytes->bytes_begin(), 4); + return result; +} + +template <> +FailureOr ParserHead::parseLiteral() { + auto bytes = consumeNBytes(8); + if (failed(bytes)) + return failure(); + double result; + std::memcpy(&result, bytes->bytes_begin(), 8); + return result; +} + +template <> +FailureOr ParserHead::parseLiteral() { + char const *error = nullptr; + uint32_t res{0}; + unsigned encodingSize{0}; + StringRef src = curHead(); + uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + if (std::isgreater(decoded, std::numeric_limits::max())) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast(decoded); + offset += encodingSize; + return res; +} + +template <> +FailureOr ParserHead::parseLiteral() { + char const *error = nullptr; + int32_t res{0}; + unsigned encodingSize{0}; + StringRef src = curHead(); + int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + if (std::isgreater(decoded, std::numeric_limits::max()) || + std::isgreater(std::numeric_limits::min(), decoded)) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast(decoded); + offset += encodingSize; + return res; +} + +template <> +FailureOr ParserHead::parseLiteral() { + char const *error = nullptr; + unsigned encodingSize{0}; + StringRef src = curHead(); + int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + offset += encodingSize; + return res; +} + +FailureOr ParserHead::parseVectorSize() { + return parseLiteral(); +} + +inline FailureOr ParserHead::parseUI32() { + return parseLiteral(); +} + +inline FailureOr ParserHead::parseI64() { + return parseLiteral(); +} + +template +inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) { + return emitError(*currentOpLoc, "unknown instruction opcode: ") + << static_cast(opCode); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void ValueStack::dump() const { + llvm::dbgs() << "================= Wasm ValueStack =======================\n"; + llvm::dbgs() << "size: " << size() << "\n"; + llvm::dbgs() << "" + << "\n"; + // Stack is pushed to via push_back. Therefore the top of the stack is the + // end of the vector. Iterate in reverse so that the first thing we print + // is the top of the stack. + size_t stackSize = size(); + for (size_t idx = 0; idx < stackSize; idx++) { + size_t actualIdx = stackSize - 1 - idx; + llvm::dbgs() << " "; + values[actualIdx].dump(); + } + llvm::dbgs() << "" + << "\n"; + llvm::dbgs() << "=========================================================\n"; +} +#endif + +parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { + LDBG() << "Popping from ValueStack\n" + << " Elements(s) to pop: " << operandTypes.size() << "\n" + << " Current stack size: " << values.size(); + if (operandTypes.size() > values.size()) + return emitError(*opLoc, + "stack doesn't contain enough values. Trying to get ") + << operandTypes.size() << " operands on a stack containing only " + << values.size() << " values."; + size_t stackIdxOffset = values.size() - operandTypes.size(); + SmallVector res{}; + res.reserve(operandTypes.size()); + for (size_t i{0}; i < operandTypes.size(); ++i) { + Value operand = values[i + stackIdxOffset]; + Type stackType = operand.getType(); + if (stackType != operandTypes[i]) + return emitError(*opLoc, "invalid operand type on stack. Expecting ") + << operandTypes[i] << ", value on stack is of type " << stackType + << "."; + LDBG() << " POP: " << operand; + res.push_back(operand); + } + values.resize(values.size() - operandTypes.size()); + LDBG() << " Updated stack size: " << values.size(); + return res; +} + +LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { + LDBG() << "Pushing to ValueStack\n" + << " Elements(s) to push: " << results.size() << "\n" + << " Current stack size: " << values.size(); + for (Value val : results) { + if (!isWasmValueType(val.getType())) + return emitError(*opLoc, "invalid value type on stack: ") + << val.getType(); + LDBG() << " PUSH: " << val; + values.push_back(val); + } + + LDBG() << " Updated stack size: " << values.size(); + return success(); +} + +template +parsed_inst_t ExpressionParser::parse(OpBuilder &builder, + UniqueByte endByte) { + auto res = parse(builder, ByteSequence{}); + if (failed(res)) + return failure(); + return res->opResults; +} + +template +FailureOr +ExpressionParser::parse(OpBuilder &builder, + ByteSequence parsingEndFilters) { + SmallVector res; + for (;;) { + currentOpLoc = parser.getLocation(); + FailureOr opCode = parser.consumeByte(); + if (failed(opCode)) + return failure(); + if (isValueOneOf(*opCode, parsingEndFilters)) + return {{res, *opCode}}; + parsed_inst_t resParsed; + resParsed = dispatchToInstParser(*opCode, builder); + if (failed(resParsed)) + return failure(); + std::swap(res, *resParsed); + if (failed(pushResults(res))) + return failure(); + } +} + +template +inline Type buildLiteralType(OpBuilder &); + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getF32Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getF64Type(); +} + +template >> +struct AttrHolder; + +template +struct AttrHolder>> { + using type = IntegerAttr; +}; + +template +struct AttrHolder>> { + using type = FloatAttr; +}; + +template +using attr_holder_t = typename AttrHolder::type; + +template >> +attr_holder_t buildLiteralAttr(OpBuilder &builder, ValT val) { + return attr_holder_t::get(buildLiteralType(builder), val); +} + +template +parsed_inst_t ExpressionParser::parseConstInst( + OpBuilder &builder, std::enable_if_t> *) { + auto parsedConstant = parser.parseLiteral(); + if (failed(parsedConstant)) + return failure(); + auto constOp = builder.create( + *currentOpLoc, buildLiteralAttr(builder, *parsedConstant)); + return {{constOp.getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) { + return parseConstInst(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) { + return parseConstInst(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) { + return parseConstInst(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) { + return parseConstInst(builder); +} + +class WasmBinaryParser { +private: + struct SectionRegistry { + using section_location_t = StringRef; + + std::array, highestWasmSectionID + 1> + registry; + + template + std::conditional_t, + ArrayRef> + getContentForSection() const { + constexpr auto idx = static_cast(SecType); + if constexpr (sectionShouldBeUnique(SecType)) { + return registry[idx].empty() ? std::nullopt + : std::make_optional(registry[idx][0]); + } else { + return registry[idx]; + } + } + + bool hasSection(WasmSectionType secType) const { + return !registry[static_cast(secType)].empty(); + } + + /// + /// @returns success if registration valid, failure in case registration + /// can't be done (if another section of same type already exist and this + /// section type should only be present once) + /// + LogicalResult registerSection(WasmSectionType secType, + section_location_t location, Location loc) { + if (sectionShouldBeUnique(secType) && hasSection(secType)) + return emitError(loc, + "trying to add a second instance of unique section"); + + registry[static_cast(secType)].push_back(location); + emitRemark(loc, "Adding section with section ID ") + << static_cast(secType); + return success(); + } + + LogicalResult populateFromBody(ParserHead ph) { + while (!ph.end()) { + FileLineColLoc sectionLoc = ph.getLocation(); + FailureOr secType = ph.parseWasmSectionType(); + if (failed(secType)) + return failure(); + + FailureOr secSizeParsed = ph.parseLiteral(); + if (failed(secSizeParsed)) + return failure(); + + uint32_t secSize = *secSizeParsed; + FailureOr sectionContent = ph.consumeNBytes(secSize); + if (failed(sectionContent)) + return failure(); + + LogicalResult registration = + registerSection(*secType, *sectionContent, sectionLoc); + + if (failed(registration)) + return failure(); + } + return success(); + } + }; + + auto getLocation(int offset = 0) const { + return FileLineColLoc::get(srcName, 0, offset); + } + + template + LogicalResult parseSectionItem(ParserHead &, size_t); + + template + LogicalResult parseSection() { + auto secName = std::string{wasmSectionName
}; + auto sectionNameAttr = + StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION"); + unsigned offset = 0; + auto getLocation = [sectionNameAttr, &offset]() { + return FileLineColLoc::get(sectionNameAttr, 0, offset); + }; + auto secContent = registry.getContentForSection
(); + if (!secContent) { + LDBG() << secName << " section is not present in file."; + return success(); + } + + auto secSrc = secContent.value(); + ParserHead ph{secSrc, sectionNameAttr}; + FailureOr nElemsParsed = ph.parseVectorSize(); + if (failed(nElemsParsed)) + return failure(); + uint32_t nElems = *nElemsParsed; + LDBG() << "Starting to parse " << nElems << " items for section " + << secName; + for (size_t i = 0; i < nElems; ++i) { + if (failed(parseSectionItem
(ph, i))) + return failure(); + } + + if (!ph.end()) + return emitError(getLocation(), "unparsed garbage at end of section ") + << secName; + return success(); + } + + /// Handles the registration of a function import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TypeIdxRecord tid) { + using llvm::Twine; + if (tid.id >= symbols.moduleFuncTypes.size()) + return emitError(loc, "invalid type id: ") + << tid.id << ". Only " << symbols.moduleFuncTypes.size() + << " type registration."; + FunctionType type = symbols.moduleFuncTypes[tid.id]; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = + builder.create(loc, symbol, moduleName, importName, type); + symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type}); + return funcOp.verify(); + } + + /// Handles the registration of a memory import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, LimitType limitType) { + std::string symbol = symbols.getNewMemorySymbolName(); + auto memOp = builder.create(loc, symbol, moduleName, + importName, limitType); + symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)}); + return memOp.verify(); + } + + /// Handles the registration of a table import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TableType tableType) { + std::string symbol = symbols.getNewTableSymbolName(); + auto tableOp = builder.create(loc, symbol, moduleName, + importName, tableType); + symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)}); + return tableOp.verify(); + } + + /// Handles the registration of a global variable import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, GlobalTypeRecord globalType) { + std::string symbol = symbols.getNewGlobalSymbolName(); + auto giOp = + builder.create(loc, symbol, moduleName, importName, + globalType.type, globalType.isMutable); + symbols.globalSymbols.push_back( + {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()}); + return giOp.verify(); + } + + // Detect occurence of errors + LogicalResult peekDiag(Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) + isValid = false; + return failure(); + } + +public: + WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx) + : builder{ctx}, ctx{ctx} { + ctx->getDiagEngine().registerHandler( + [this](Diagnostic &diag) { return peekDiag(diag); }); + ctx->loadAllAvailableDialects(); + if (sourceMgr.getNumBuffers() != 1) { + emitError(UnknownLoc::get(ctx), "one source file should be provided"); + return; + } + uint32_t sourceBufId = sourceMgr.getMainFileID(); + StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer(); + srcName = StringAttr::get( + ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier()); + + auto parser = ParserHead{source, srcName}; + auto const wasmHeader = StringRef{"\0asm", 4}; + FileLineColLoc magicLoc = parser.getLocation(); + FailureOr magic = parser.consumeNBytes(wasmHeader.size()); + if (failed(magic) || magic->compare(wasmHeader)) { + emitError(magicLoc, "source file does not contain valid Wasm header."); + return; + } + auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; + FileLineColLoc versionLoc = parser.getLocation(); + FailureOr version = + parser.consumeNBytes(expectedVersionString.size()); + if (failed(version)) + return; + if (version->compare(expectedVersionString)) { + emitError(versionLoc, + "unsupported Wasm version. Only version 1 is supported."); + return; + } + LogicalResult fillRegistry = registry.populateFromBody(parser.copy()); + if (failed(fillRegistry)) + return; + + mOp = builder.create(getLocation()); + builder.setInsertionPointToStart(&mOp.getBodyRegion().front()); + LogicalResult parsingTypes = parseSection(); + if (failed(parsingTypes)) + return; + + LogicalResult parsingImports = parseSection(); + if (failed(parsingImports)) + return; + + firstInternalFuncID = symbols.funcSymbols.size(); + + LogicalResult parsingFunctions = parseSection(); + if (failed(parsingFunctions)) + return; + + LogicalResult parsingTables = parseSection(); + if (failed(parsingTables)) + return; + + LogicalResult parsingMems = parseSection(); + if (failed(parsingMems)) + return; + + LogicalResult parsingExports = parseSection(); + if (failed(parsingExports)) + return; + + // Copy over sizes of containers into statistics. + numFunctionSectionItems = symbols.funcSymbols.size(); + numGlobalSectionItems = symbols.globalSymbols.size(); + numMemorySectionItems = symbols.memSymbols.size(); + numTableSectionItems = symbols.tableSymbols.size(); + } + + ModuleOp getModule() { + if (isValid) + return mOp; + return ModuleOp{}; + } + +private: + mlir::StringAttr srcName; + OpBuilder builder; + WasmModuleSymbolTables symbols; + MLIRContext *ctx; + ModuleOp mOp; + SectionRegistry registry; + size_t firstInternalFuncID{0}; + bool isValid{true}; +}; + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc importLoc = ph.getLocation(); + auto moduleName = ph.parseName(); + if (failed(moduleName)) + return failure(); + + auto importName = ph.parseName(); + if (failed(importName)) + return failure(); + + FailureOr import = ph.parseImportDesc(ctx); + if (failed(import)) + return failure(); + + return std::visit( + [this, importLoc, &moduleName, &importName](auto import) { + return visitImport(importLoc, *moduleName, *importName, import); + }, + *import); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc exportLoc = ph.getLocation(); + + auto exportName = ph.parseName(); + if (failed(exportName)) + return failure(); + + FailureOr opcode = ph.consumeByte(); + if (failed(opcode)) + return failure(); + + FailureOr idx = ph.parseLiteral(); + if (failed(idx)) + return failure(); + + using SymbolRefDesc = std::variant, + SmallVector, + SmallVector>; + + SymbolRefDesc currentSymbolList; + std::string symbolType = ""; + switch (*opcode) { + case WasmBinaryEncoding::Export::function: + symbolType = "function"; + currentSymbolList = symbols.funcSymbols; + break; + case WasmBinaryEncoding::Export::table: + symbolType = "table"; + currentSymbolList = symbols.tableSymbols; + break; + case WasmBinaryEncoding::Export::memory: + symbolType = "memory"; + currentSymbolList = symbols.memSymbols; + break; + case WasmBinaryEncoding::Export::global: + symbolType = "global"; + currentSymbolList = symbols.globalSymbols; + break; + default: + return emitError(exportLoc, "invalid value for export type: ") + << std::to_integer(*opcode); + } + + auto currentSymbol = std::visit( + [&](const auto &list) -> FailureOr { + if (*idx > list.size()) { + emitError( + exportLoc, + llvm::formatv( + "trying to export {0} {1} which is undefined in this scope", + symbolType, *idx)); + return failure(); + } + return list[*idx].symbol; + }, + currentSymbolList); + + if (failed(currentSymbol)) + return failure(); + + Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); + SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); + StringAttr symName = SymbolTable::getSymbolName(op); + return SymbolTable{mOp}.rename(symName, *exportName); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr tableType = ph.parseTableType(ctx); + if (failed(tableType)) + return failure(); + LDBG() << " Parsed table description: " << *tableType; + StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); + auto tableOp = + builder.create(opLocation, symbol.strref(), *tableType); + symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)}); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc opLoc = ph.getLocation(); + auto typeIdxParsed = ph.parseLiteral(); + if (failed(typeIdxParsed)) + return failure(); + uint32_t typeIdx = *typeIdxParsed; + if (typeIdx >= symbols.moduleFuncTypes.size()) + return emitError(getLocation(), "invalid type index: ") << typeIdx; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = + builder.create(opLoc, symbol, symbols.moduleFuncTypes[typeIdx]); + Block *block = funcOp.addEntryBlock(); + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToEnd(block); + builder.create(opLoc); + builder.restoreInsertionPoint(ip); + symbols.funcSymbols.push_back( + {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())}, + symbols.moduleFuncTypes[typeIdx]}); + return funcOp.verify(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FailureOr funcType = ph.parseFunctionType(ctx); + if (failed(funcType)) + return failure(); + LDBG() << "Parsed function type " << *funcType; + symbols.moduleFuncTypes.push_back(*funcType); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr memory = ph.parseLimit(ctx); + if (failed(memory)) + return failure(); + + LDBG() << " Registering memory " << *memory; + std::string symbol = symbols.getNewMemorySymbolName(); + auto memOp = builder.create(opLocation, symbol, *memory); + symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)}); + return success(); +} +} // namespace + +namespace mlir::wasm { +OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, + MLIRContext *context) { + WasmBinaryParser wBN{source, context}; + ModuleOp mOp = wBN.getModule(); + if (mOp) + return {mOp}; + + return {nullptr}; +} +} // namespace mlir::wasm diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp new file mode 100644 index 0000000000000..03b97846d45d3 --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp @@ -0,0 +1,28 @@ +//===- TranslateRegistration.cpp - Register translation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "mlir/Tools/mlir-translate/Translation.h" + +using namespace mlir; + +namespace mlir { +void registerFromWasmTranslation() { + TranslateToMLIRRegistration registration{ + "import-wasm", "Translate WASM to MLIR", + [](llvm::SourceMgr &sourceMgr, + MLIRContext *context) -> OwningOpRef { + return wasm::importWebAssemblyToModule(sourceMgr, context); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }}; +} +} // namespace mlir diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index a4a942de3c9a7..016d5e6f6e914 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -123,7 +123,7 @@ set(MLIR_TEST_DEPENDS tblgen-to-irdl ) if(NOT MLIR_STANDALONE_BUILD) - list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file) + list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file yaml2obj) endif() set(MLIR_TEST_DEPENDS ${MLIR_TEST_DEPENDS} diff --git a/mlir/test/Target/Wasm/bad_wasm_version.yaml b/mlir/test/Target/Wasm/bad_wasm_version.yaml new file mode 100644 index 0000000000000..f834afbef679d --- /dev/null +++ b/mlir/test/Target/Wasm/bad_wasm_version.yaml @@ -0,0 +1,8 @@ +# RUN: yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s + +# CHECK: unsupported Wasm version + +--- !WASM +FileHeader: + Version: 0xDEADBEEF +... diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml new file mode 100644 index 0000000000000..b08c2c87abdb3 --- /dev/null +++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml @@ -0,0 +1,13 @@ +# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s + +# CHECK: trying to export function 42 which is undefined in this scope + +--- !WASM +FileHeader: + Version: 0x00000001 +Sections: + - Type: EXPORT + Exports: + - Name: function_export + Kind: FUNCTION + Index: 42 diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir new file mode 100644 index 0000000000000..541dcf3a2d9eb --- /dev/null +++ b/mlir/test/Target/Wasm/import.mlir @@ -0,0 +1,19 @@ +// RUN: yaml2obj %S/inputs/import.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module +(import "my_module" "foo" (func $foo (param i32))) +(import "my_module" "bar" (func $bar (param i32))) +(import "my_module" "table" (table $round 2 funcref)) +(import "my_module" "mem" (memory $mymem 2)) +(import "my_module" "glob" (global $globglob i32)) +(import "my_other_module" "glob_mut" (global $glob_mut (mut i32))) +) +*/ + +// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()} +// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()} +// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa} +// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa, sym_visibility = "nested"} +// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32 +// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32 diff --git a/mlir/test/Target/Wasm/inputs/import.yaml.wasm b/mlir/test/Target/Wasm/inputs/import.yaml.wasm new file mode 100644 index 0000000000000..7c467ff6fbc67 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/import.yaml.wasm @@ -0,0 +1,44 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: [] + - Type: IMPORT + Imports: + - Module: my_module + Field: foo + Kind: FUNCTION + SigIndex: 0 + - Module: my_module + Field: bar + Kind: FUNCTION + SigIndex: 0 + - Module: my_module + Field: table + Kind: TABLE + Table: + Index: 0 + ElemType: FUNCREF + Limits: + Minimum: 0x2 + - Module: my_module + Field: mem + Kind: MEMORY + Memory: + Minimum: 0x2 + - Module: my_module + Field: glob + Kind: GLOBAL + GlobalType: I32 + GlobalMutable: false + - Module: my_other_module + Field: glob_mut + Kind: GLOBAL + GlobalType: I32 + GlobalMutable: true +... diff --git a/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm new file mode 100644 index 0000000000000..f3edf5f2d0cc2 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm @@ -0,0 +1,10 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: MEMORY + Memories: + - Flags: [ HAS_MAX ] + Minimum: 0x0 + Maximum: 0x0 +... diff --git a/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm new file mode 100644 index 0000000000000..fe70fb686df37 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm @@ -0,0 +1,10 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: MEMORY + Memories: + - Flags: [ HAS_MAX ] + Minimum: 0x0 + Maximum: 0x10000 +... diff --git a/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm new file mode 100644 index 0000000000000..8508ce38251a3 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm @@ -0,0 +1,8 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: MEMORY + Memories: + - Minimum: 0x1 +... diff --git a/mlir/test/Target/Wasm/inputs/stats.yaml.wasm b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm new file mode 100644 index 0000000000000..bf577688b3aed --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm @@ -0,0 +1,38 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: TABLE + Tables: + - Index: 0 + ElemType: FUNCREF + Limits: + Minimum: 0x2 + - Type: MEMORY + Memories: + - Flags: [ HAS_MAX ] + Minimum: 0x0 + Maximum: 0x10000 + - Type: GLOBAL + Globals: + - Index: 0 + Type: I32 + Mutable: false + InitExpr: + Opcode: I32_CONST + Value: 10 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 20000B +... diff --git a/mlir/test/Target/Wasm/inputs/table.yaml.wasm b/mlir/test/Target/Wasm/inputs/table.yaml.wasm new file mode 100644 index 0000000000000..387f41820524f --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/table.yaml.wasm @@ -0,0 +1,23 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TABLE + Tables: + - Index: 0 + ElemType: FUNCREF + Limits: + Minimum: 0x2 + - Index: 1 + ElemType: FUNCREF + Limits: + Flags: [ HAS_MAX ] + Minimum: 0x2 + Maximum: 0x4 + - Index: 2 + ElemType: EXTERNREF + Limits: + Flags: [ HAS_MAX ] + Minimum: 0x2 + Maximum: 0x4 +... diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml new file mode 100644 index 0000000000000..2d2954aa32dda --- /dev/null +++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml @@ -0,0 +1,16 @@ +# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s +# CHECK: error: invalid type index: 2 + +--- !WASM +FileHeader: + Version: 0x00000001 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: [] + - Type: FUNCTION + FunctionTypes: + - 2 diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir new file mode 100644 index 0000000000000..2ba5ab50d51fa --- /dev/null +++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir @@ -0,0 +1,7 @@ +// RUN: yaml2obj %S/inputs/memory_min_eq_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module (memory 0 0)) +*/ + +// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir new file mode 100644 index 0000000000000..ebf64189189f8 --- /dev/null +++ b/mlir/test/Target/Wasm/memory_min_max.mlir @@ -0,0 +1,7 @@ +// RUN: yaml2obj %S/inputs/memory_min_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module (memory 0 65536)) +*/ + +// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir new file mode 100644 index 0000000000000..8d8878618bcc0 --- /dev/null +++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir @@ -0,0 +1,7 @@ +// RUN: yaml2obj %S/inputs/memory_min_no_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module (memory 1)) +*/ + +// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa diff --git a/mlir/test/Target/Wasm/missing_header.yaml b/mlir/test/Target/Wasm/missing_header.yaml new file mode 100644 index 0000000000000..a9f812e0a77f8 --- /dev/null +++ b/mlir/test/Target/Wasm/missing_header.yaml @@ -0,0 +1,12 @@ +# RUN: not yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s + +# CHECK: source file does not contain valid Wasm header + +--- !WASM +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: [] +... diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir new file mode 100644 index 0000000000000..b361de3d99f31 --- /dev/null +++ b/mlir/test/Target/Wasm/stats.mlir @@ -0,0 +1,19 @@ +// RUN: yaml2obj %S/inputs/stats.yaml.wasm -o - | mlir-translate --import-wasm -stats 2>&1 | FileCheck %s +// Check that we get the correct stats for a module that has a single +// function, table, memory, and global. +// REQUIRES: asserts + +/* Source code used to create this test: +(module + (type (;0;) (func (param i32) (result i32))) + (func (;0;) (type 0) (param i32) (result i32) + local.get 0) + (table (;0;) 2 funcref) + (memory (;0;) 0 65536) + (global (;0;) i32 (i32.const 10))) +*/ + +// CHECK: 1 wasm-translate - Parsed functions +// CHECK-NEXT: 0 wasm-translate - Parsed globals +// CHECK-NEXT: 1 wasm-translate - Parsed memories +// CHECK-NEXT: 1 wasm-translate - Parsed tables