From b6cc91e97b401de7724ac17dd0ca24f4e492398b Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Mon, 30 Jun 2025 15:46:25 +0200 Subject: [PATCH 01/14] [mlir][wasm] Adding wasm import target to mlir-translate This commit contains basic parsing infrastructure + base code to parse wasm binary file type section. --------- Co-authored-by: Ferdinand Lemaire Co-authored-by: Jessica Paquette --- mlir/include/mlir/InitAllTranslations.h | 3 +- .../mlir/Target/Wasm/WasmBinaryEncoding.h | 55 ++ mlir/include/mlir/Target/Wasm/WasmImporter.h | 35 + mlir/lib/Target/CMakeLists.txt | 1 + mlir/lib/Target/Wasm/CMakeLists.txt | 13 + mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 797 ++++++++++++++++++ .../lib/Target/Wasm/TranslateRegistration.cpp | 28 + mlir/test/Target/Wasm/bad_wasm_version.yaml | 8 + mlir/test/Target/Wasm/import.mlir | 19 + mlir/test/Target/Wasm/inputs/import.yaml.wasm | 44 + mlir/test/Target/Wasm/inputs/stats.yaml.wasm | 38 + .../Wasm/invalid_function_type_index.yaml | 18 + mlir/test/Target/Wasm/missing_header.yaml | 12 + mlir/test/Target/Wasm/stats.mlir | 19 + 14 files changed, 1089 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h create mode 100644 mlir/include/mlir/Target/Wasm/WasmImporter.h create mode 100644 mlir/lib/Target/Wasm/CMakeLists.txt create mode 100644 mlir/lib/Target/Wasm/TranslateFromWasm.cpp create mode 100644 mlir/lib/Target/Wasm/TranslateRegistration.cpp create mode 100644 mlir/test/Target/Wasm/bad_wasm_version.yaml create mode 100644 mlir/test/Target/Wasm/import.mlir create mode 100644 mlir/test/Target/Wasm/inputs/import.yaml.wasm create mode 100644 mlir/test/Target/Wasm/inputs/stats.yaml.wasm create mode 100644 mlir/test/Target/Wasm/invalid_function_type_index.yaml create mode 100644 mlir/test/Target/Wasm/missing_header.yaml create mode 100644 mlir/test/Target/Wasm/stats.mlir diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h index 1ab80fb27fa9a..cf8f108b88159 100644 --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -17,9 +17,9 @@ #include "mlir/Target/IRDLToCpp/TranslationRegistration.h" namespace mlir { - void registerFromLLVMIRTranslation(); void registerFromSPIRVTranslation(); +void registerFromWasmTranslation(); void registerToCppTranslation(); void registerToLLVMIRTranslation(); void registerToSPIRVTranslation(); @@ -36,6 +36,7 @@ inline void registerAllTranslations() { registerFromLLVMIRTranslation(); registerFromSPIRVTranslation(); registerIRDLToCppTranslation(); + registerFromWasmTranslation(); registerToCppTranslation(); registerToLLVMIRTranslation(); registerToSPIRVTranslation(); diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h new file mode 100644 index 0000000000000..e01193e47fdea --- /dev/null +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -0,0 +1,55 @@ +//===- WasmBinaryEncoding.h - Byte encodings for Wasm binary format ===----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Define encodings for WebAssembly instructions, types, etc from the +// WebAssembly binary format. +// +// Each encoding is defined in the WebAssembly binary specification. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TARGET_WASMBINARYENCODING +#define MLIR_TARGET_WASMBINARYENCODING + +#include +namespace mlir { +struct WasmBinaryEncoding { + /// Byte encodings of types in WASM binaries + struct Type { + static constexpr std::byte emptyBlockType{0x40}; + static constexpr std::byte funcType{0x60}; + static constexpr std::byte externRef{0x6F}; + static constexpr std::byte funcRef{0x70}; + static constexpr std::byte v128{0x7B}; + static constexpr std::byte f64{0x7C}; + static constexpr std::byte f32{0x7D}; + static constexpr std::byte i64{0x7E}; + static constexpr std::byte i32{0x7F}; + }; + + /// Byte encodings of WASM imports. + struct Import { + static constexpr std::byte typeID{0x00}; + static constexpr std::byte tableType{0x01}; + static constexpr std::byte memType{0x02}; + static constexpr std::byte globalType{0x03}; + }; + + /// Byte encodings for WASM limits. + struct LimitHeader { + static constexpr std::byte lowLimitOnly{0x00}; + static constexpr std::byte bothLimits{0x01}; + }; + + /// Byte encodings describing the mutability of globals. + struct GlobalMutability { + static constexpr std::byte isConst{0x00}; + static constexpr std::byte isMutable{0x01}; + }; + +}; +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h new file mode 100644 index 0000000000000..fc7d275353964 --- /dev/null +++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h @@ -0,0 +1,35 @@ +//===- WasmImporter.h - Helpers to create WebAssembly emitter ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines helpers to import WebAssembly code using the WebAssembly +// dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_WASM_WASMIMPORTER_H +#define MLIR_TARGET_WASM_WASMIMPORTER_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace wasm { + +/// Translates the given operation to C++ code. The operation or operations in +/// the region of 'op' need almost all be in EmitC dialect. The parameter +/// 'declareVariablesAtTop' enforces that all variables for op results and block +/// arguments are declared at the beginning of the function. +/// If parameter 'fileId' is non-empty, then body of `emitc.file` ops +/// with matching id are emitted. +OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext* context); +} // namespace wasm +} // namespace mlir + +#endif // MLIR_TARGET_WASM_WASMIMPORTER_H diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index 6eb0abc214d38..f0c3ac4d511c1 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(SPIRV) add_subdirectory(LLVMIR) add_subdirectory(LLVM) add_subdirectory(SMTLIB) +add_subdirectory(Wasm) diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt new file mode 100644 index 0000000000000..890fc0ecfbeb6 --- /dev/null +++ b/mlir/lib/Target/Wasm/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(MLIRTargetWasmImport + TranslateRegistration.cpp + TranslateFromWasm.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm + + LINK_LIBS PUBLIC + MLIRWasmSSADialect + MLIRIR + MLIRSupport + MLIRTranslateLib +) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp new file mode 100644 index 0000000000000..2962cc212f848 --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -0,0 +1,797 @@ +//===- TranslateFromWasm.cpp - Translating to C++ calls -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/Target/Wasm/WasmBinaryEncoding.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LEB128.h" + +#include + +#define DEBUG_TYPE "wasm-translate" + +// Statistics. +STATISTIC(numFunctionSectionItems, "Parsed functions"); +STATISTIC(numGlobalSectionItems, "Parsed globals"); +STATISTIC(numMemorySectionItems, "Parsed memories"); +STATISTIC(numTableSectionItems, "Parsed tables"); + +static_assert(CHAR_BIT == 8, "This code expects std::byte to be exactly 8 bits"); + +using namespace mlir; +using namespace mlir::wasm; +using namespace mlir::wasmssa; + +namespace { +using section_id_t = uint8_t; +enum struct WasmSectionType : section_id_t { + CUSTOM = 0, + TYPE = 1, + IMPORT = 2, + FUNCTION = 3, + TABLE = 4, + MEMORY = 5, + GLOBAL = 6, + EXPORT = 7, + START = 8, + ELEMENT = 9, + CODE = 10, + DATA = 11, + DATACOUNT = 12 +}; + +constexpr section_id_t highestWasmSectionID{ + static_cast(WasmSectionType::DATACOUNT)}; + +#define APPLY_WASM_SEC_TRANSFORM \ + WASM_SEC_TRANSFORM(CUSTOM) \ + WASM_SEC_TRANSFORM(TYPE) \ + WASM_SEC_TRANSFORM(IMPORT) \ + WASM_SEC_TRANSFORM(FUNCTION) \ + WASM_SEC_TRANSFORM(TABLE) \ + WASM_SEC_TRANSFORM(MEMORY) \ + WASM_SEC_TRANSFORM(GLOBAL) \ + WASM_SEC_TRANSFORM(EXPORT) \ + WASM_SEC_TRANSFORM(START) \ + WASM_SEC_TRANSFORM(ELEMENT) \ + WASM_SEC_TRANSFORM(CODE) \ + WASM_SEC_TRANSFORM(DATA) \ + WASM_SEC_TRANSFORM(DATACOUNT) + +template +constexpr const char *wasmSectionName = ""; + +#define WASM_SEC_TRANSFORM(section) \ + template <> \ + constexpr const char *wasmSectionName = #section; +APPLY_WASM_SEC_TRANSFORM +#undef WASM_SEC_TRANSFORM + +constexpr bool sectionShouldBeUnique(WasmSectionType secType) { + return secType != WasmSectionType::CUSTOM; +} + +template +struct ByteSequence{}; + +template +constexpr ByteSequence +operator+(ByteSequence, ByteSequence) { + return {}; +} + +/// Template class for representing a byte sequence of only one byte +template +struct UniqueByte : ByteSequence {}; + +template +constexpr ByteSequence +byteSeqFromIntSeq(std::integer_sequence) { + return {}; +} + +constexpr auto allOpCodes = + byteSeqFromIntSeq(std::make_integer_sequence()); + +constexpr ByteSequence< + WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64, + WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64, + WasmBinaryEncoding::Type::v128> + valueTypesEncodings{}; + +template +constexpr bool isValueOneOf(std::byte value, ByteSequence = {}) { + return ((value == allowedFlags) | ... | false); +} + +template +constexpr bool isNotIn(std::byte value, ByteSequence = {}) { + return !isValueOneOf(value); +} + +struct GlobalTypeRecord { + Type type; + bool isMutable; +}; + +struct TypeIdxRecord { + size_t id; +}; + +struct SymbolRefContainer { + FlatSymbolRefAttr symbol; +}; + +struct GlobalSymbolRefContainer : SymbolRefContainer { + Type globalType; +}; + +struct FunctionSymbolRefContainer : SymbolRefContainer { + FunctionType functionType; +}; + +using ImportDesc = std::variant; + +struct WasmModuleSymbolTables { + llvm::SmallVector funcSymbols; + llvm::SmallVector globalSymbols; + llvm::SmallVector memSymbols; + llvm::SmallVector tableSymbols; + llvm::SmallVector moduleFuncTypes; + + std::string getNewSymbolName(llvm::StringRef prefix, size_t id) const { + return (prefix + llvm::Twine{id}).str(); + } + + std::string getNewFuncSymbolName() const { + auto id = funcSymbols.size(); + return getNewSymbolName("func_", id); + } + + std::string getNewGlobalSymbolName() const { + auto id = globalSymbols.size(); + return getNewSymbolName("global_", id); + } + + std::string getNewMemorySymbolName() const { + auto id = memSymbols.size(); + return getNewSymbolName("mem_", id); + } + + std::string getNewTableSymbolName() const { + auto id = tableSymbols.size(); + return getNewSymbolName("table_", id); + } +}; +class ParserHead { +public: + ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {} + ParserHead(ParserHead &&) = default; +private: + ParserHead(ParserHead const &other) = default; + +public: + auto getLocation() const { + return FileLineColLoc::get(locName, 0, anchorOffset + offset); + } + + llvm::FailureOr consumeNBytes(size_t nBytes) { + LLVM_DEBUG(llvm::dbgs() << "Consume " << nBytes << " bytes\n"); + LLVM_DEBUG(llvm::dbgs() << " Bytes remaining: " << size() << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Current offset: " << offset << "\n"); + if (nBytes > size()) + return emitError(getLocation(), "trying to extract ") + << nBytes << "bytes when only " << size() << "are avilables"; + + auto res = head.slice(offset, offset + nBytes); + offset += nBytes; + LLVM_DEBUG(llvm::dbgs() + << " Updated offset (+" << nBytes << "): " << offset << "\n"); + return res; + } + + llvm::FailureOr consumeByte() { + auto res = consumeNBytes(1); + if (failed(res)) + return failure(); + return std::byte{*res->bytes_begin()}; + } + + template + llvm::FailureOr parseLiteral(); + + llvm::FailureOr parseVectorSize(); + +private: + // TODO: This is equivalent to parseLiteral and could be removed + // if parseLiteral specialization were moved here, but default GCC on Ubuntu + // 22.04 has bug with template specialization in class declaration + inline llvm::FailureOr parseUI32(); + inline llvm::FailureOr parseI64(); + +public: + llvm::FailureOr parseName() { + auto size = parseVectorSize(); + if (failed(size)) + return failure(); + + return consumeNBytes(*size); + } + + llvm::FailureOr parseWasmSectionType() { + auto id = consumeByte(); + if (failed(id)) + return failure(); + if (std::to_integer(*id) > highestWasmSectionID) + return emitError(getLocation(), "Invalid section ID: ") + << static_cast(*id); + return static_cast(*id); + } + + llvm::FailureOr parseLimit(MLIRContext *ctx) { + using WasmLimits = WasmBinaryEncoding::LimitHeader; + auto limitLocation = getLocation(); + auto limitHeader = consumeByte(); + if (failed(limitHeader)) + return failure(); + + if (isNotIn(*limitHeader)) + return emitError(limitLocation, "Invalid limit header: ") + << static_cast(*limitHeader); + auto minParse = parseUI32(); + if (failed(minParse)) + return failure(); + std::optional max{std::nullopt}; + if (*limitHeader == WasmLimits::bothLimits) { + auto maxParse = parseUI32(); + if (failed(maxParse)) + return failure(); + max = *maxParse; + } + return LimitType::get(ctx, *minParse, max); + } + + llvm::FailureOr parseValueType(MLIRContext *ctx) { + auto typeLoc = getLocation(); + auto typeEncoding = consumeByte(); + if (failed(typeEncoding)) + return failure(); + switch (*typeEncoding) { + case WasmBinaryEncoding::Type::i32: + return IntegerType::get(ctx, 32); + case WasmBinaryEncoding::Type::i64: + return IntegerType::get(ctx, 64); + case WasmBinaryEncoding::Type::f32: + return Float32Type::get(ctx); + case WasmBinaryEncoding::Type::f64: + return Float64Type::get(ctx); + case WasmBinaryEncoding::Type::v128: + return IntegerType::get(ctx, 128); + case WasmBinaryEncoding::Type::funcRef: + return wasmssa::FuncRefType::get(ctx); + case WasmBinaryEncoding::Type::externRef: + return wasmssa::ExternRefType::get(ctx); + default: + return emitError(typeLoc, "Invalid value type encoding: ") + << static_cast(*typeEncoding); + } + } + + llvm::FailureOr parseGlobalType(MLIRContext *ctx) { + using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability; + auto typeParsed = parseValueType(ctx); + if (failed(typeParsed)) + return failure(); + auto mutLoc = getLocation(); + auto mutSpec = consumeByte(); + if (failed(mutSpec)) + return failure(); + if (isNotIn(*mutSpec)) + return emitError(mutLoc, "Invalid global mutability specifier: ") + << static_cast(*mutSpec); + return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable}; + } + + llvm::FailureOr parseResultType(MLIRContext *ctx) { + auto nParamsParsed = parseVectorSize(); + if (failed(nParamsParsed)) + return failure(); + auto nParams = *nParamsParsed; + llvm::SmallVector res{}; + res.reserve(nParams); + for (size_t i = 0; i < nParams; ++i) { + auto parsedType = parseValueType(ctx); + if (failed(parsedType)) + return failure(); + res.push_back(*parsedType); + } + return TupleType::get(ctx, res); + } + + llvm::FailureOr parseFunctionType(MLIRContext *ctx) { + auto typeLoc = getLocation(); + auto funcTypeHeader = consumeByte(); + if (failed(funcTypeHeader)) + return failure(); + if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType) + return emitError(typeLoc, "Invalid function type header byte. Expecting ") + << std::to_integer( + WasmBinaryEncoding::Type::funcType) + << " got " << std::to_integer(*funcTypeHeader); + auto inputTypes = parseResultType(ctx); + if (failed(inputTypes)) + return failure(); + + auto resTypes = parseResultType(ctx); + if (failed(resTypes)) + return failure(); + + return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes()); + } + + llvm::FailureOr parseTypeIndex() { + auto res = parseUI32(); + if (failed(res)) + return failure(); + return TypeIdxRecord{*res}; + } + + llvm::FailureOr parseTableType(MLIRContext *ctx) { + auto elmTypeParse = parseValueType(ctx); + if (failed(elmTypeParse)) + return failure(); + if (!isWasmRefType(*elmTypeParse)) + return emitError(getLocation(), "Invalid element type for table"); + auto limitParse = parseLimit(ctx); + if (failed(limitParse)) + return failure(); + return TableType::get(ctx, *elmTypeParse, *limitParse); + } + + llvm::FailureOr parseImportDesc(MLIRContext *ctx) { + auto importLoc = getLocation(); + auto importType = consumeByte(); + auto packager = [](auto parseResult) -> llvm::FailureOr { + if (llvm::failed(parseResult)) + return failure(); + return {*parseResult}; + }; + if (failed(importType)) + return failure(); + switch (*importType) { + case WasmBinaryEncoding::Import::typeID: + return packager(parseTypeIndex()); + case WasmBinaryEncoding::Import::tableType: + return packager(parseTableType(ctx)); + case WasmBinaryEncoding::Import::memType: + return packager(parseLimit(ctx)); + case WasmBinaryEncoding::Import::globalType: + return packager(parseGlobalType(ctx)); + default: + return emitError(importLoc, "Invalid import type descriptor: ") + << static_cast(*importType); + } + } + bool end() const { return curHead().empty(); } + + ParserHead copy() const { + return *this; + } + +private: + llvm::StringRef curHead() const { return head.drop_front(offset); } + + llvm::FailureOr peek() const { + if (end()) + return emitError( + getLocation(), + "trying to peek at next byte, but input stream is empty"); + return static_cast(curHead().front()); + } + + size_t size() const { return head.size() - offset; } + + llvm::StringRef head; + StringAttr locName; + unsigned anchorOffset{0}; + unsigned offset{0}; +}; + +template <> +llvm::FailureOr ParserHead::parseLiteral() { + auto bytes = consumeNBytes(4); + if (failed(bytes)) + return failure(); + float result; + std::memcpy(&result, bytes->bytes_begin(), 4); + return result; +} + +template <> +llvm::FailureOr ParserHead::parseLiteral() { + auto bytes = consumeNBytes(8); + if (failed(bytes)) + return failure(); + double result; + std::memcpy(&result, bytes->bytes_begin(), 8); + return result; +} + +template <> +llvm::FailureOr ParserHead::parseLiteral() { + char const *error = nullptr; + uint32_t res{0}; + unsigned encodingSize{0}; + auto src = curHead(); + auto decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + if (std::isgreater(decoded, std::numeric_limits::max())) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast(decoded); + offset += encodingSize; + return res; +} + +template <> +llvm::FailureOr ParserHead::parseLiteral() { + char const *error = nullptr; + int32_t res{0}; + unsigned encodingSize{0}; + auto src = curHead(); + auto decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + if (std::isgreater(decoded, std::numeric_limits::max()) || + std::isgreater(std::numeric_limits::min(), decoded)) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast(decoded); + offset += encodingSize; + return res; +} + +template <> +llvm::FailureOr ParserHead::parseLiteral() { + char const *error = nullptr; + unsigned encodingSize{0}; + auto src = curHead(); + auto res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + offset += encodingSize; + return res; +} + +llvm::FailureOr ParserHead::parseVectorSize() { + return parseLiteral(); +} + +inline llvm::FailureOr ParserHead::parseUI32() { + return parseLiteral(); +} + +inline llvm::FailureOr ParserHead::parseI64() { + return parseLiteral(); +} + +class WasmBinaryParser { +private: + struct SectionRegistry { + using section_location_t = llvm::StringRef; + + std::array, highestWasmSectionID+1> registry; + + template + std::conditional_t, + llvm::ArrayRef> + getContentForSection() const { + constexpr auto idx = static_cast(SecType); + if constexpr (sectionShouldBeUnique(SecType)) { + return registry[idx].empty() ? std::nullopt + : std::make_optional(registry[idx][0]); + } else { + return registry[idx]; + } + } + + bool hasSection(WasmSectionType secType) const { + return !registry[static_cast(secType)].empty(); + } + + /// + /// @returns success if registration valid, failure in case registration + /// can't be done (if another section of same type already exist and this + /// section type should only be present once) + /// + LogicalResult registerSection(WasmSectionType secType, + section_location_t location, Location loc) { + if (sectionShouldBeUnique(secType) && hasSection(secType)) + return emitError(loc, + "Trying to add a second instance of unique section"); + + registry[static_cast(secType)].push_back(location); + emitRemark(loc, "Adding section with section ID ") + << static_cast(secType); + return success(); + } + + LogicalResult populateFromBody(ParserHead ph) { + while (!ph.end()) { + auto sectionLoc = ph.getLocation(); + auto secType = ph.parseWasmSectionType(); + if (failed(secType)) + return failure(); + + auto secSizeParsed = ph.parseLiteral(); + if (failed(secSizeParsed)) + return failure(); + + auto secSize = *secSizeParsed; + auto sectionContent = ph.consumeNBytes(secSize); + if (failed(sectionContent)) + return failure(); + + auto registration = + registerSection(*secType, *sectionContent, sectionLoc); + + if (failed(registration)) + return failure(); + + } + return success(); + } + }; + + auto getLocation(int offset = 0) const { + return FileLineColLoc::get(srcName, 0, offset); + } + + template + LogicalResult parseSectionItem(ParserHead &, size_t); + + template + LogicalResult parseSection() { + auto secName = std::string{wasmSectionName
}; + auto sectionNameAttr = + StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION"); + unsigned offset = 0; + auto getLocation = [sectionNameAttr, &offset]() { + return FileLineColLoc::get(sectionNameAttr, 0, offset); + }; + auto secContent = registry.getContentForSection
(); + if (!secContent) { + LLVM_DEBUG(llvm::dbgs() << secName << " section is not present in file."); + return success(); + } + + auto secSrc = secContent.value(); + ParserHead ph{secSrc, sectionNameAttr}; + auto nElemsParsed = ph.parseVectorSize(); + if (failed(nElemsParsed)) + return failure(); + auto nElems = *nElemsParsed; + LLVM_DEBUG(llvm::dbgs() << "Starting to parse " << nElems + << " items for section " << secName << ".\n"); + for (size_t i = 0; i < nElems; ++i) { + if (failed(parseSectionItem
(ph, i))) + return failure(); + } + + if (!ph.end()) + return emitError(getLocation(), "Unparsed garbage at end of section ") + << secName; + return success(); + } + + /// Handles the registration of a function import + LogicalResult visitImport(Location loc, llvm::StringRef moduleName, + llvm::StringRef importName, TypeIdxRecord tid) { + using llvm::Twine; + if (tid.id >= symbols.moduleFuncTypes.size()) + return emitError(loc, "Invalid type id: ") + << tid.id << ". Only " << symbols.moduleFuncTypes.size() + << " type registration."; + auto type = symbols.moduleFuncTypes[tid.id]; + auto symbol = symbols.getNewFuncSymbolName(); + auto funcOp = builder.create( + loc, symbol, moduleName, importName, type); + symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type}); + return funcOp.verify(); + } + + /// Handles the registration of a memory import + LogicalResult visitImport(Location loc, llvm::StringRef moduleName, + llvm::StringRef importName, LimitType limitType) { + auto symbol = symbols.getNewMemorySymbolName(); + auto memOp = builder.create(loc, symbol, moduleName, + importName, limitType); + symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)}); + return memOp.verify(); + } + + /// Handles the registration of a table import + LogicalResult visitImport(Location loc, llvm::StringRef moduleName, + llvm::StringRef importName, TableType tableType) { + auto symbol = symbols.getNewTableSymbolName(); + auto tableOp = builder.create(loc, symbol, moduleName, + importName, tableType); + symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)}); + return tableOp.verify(); + } + + /// Handles the registration of a global variable import + LogicalResult visitImport(Location loc, llvm::StringRef moduleName, + llvm::StringRef importName, + GlobalTypeRecord globalType) { + auto symbol = symbols.getNewGlobalSymbolName(); + auto giOp = + builder.create(loc, symbol, moduleName, importName, + globalType.type, globalType.isMutable); + symbols.globalSymbols.push_back({{FlatSymbolRefAttr::get(giOp)}, giOp.getType()}); + return giOp.verify(); + } + +public: + WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx) + : builder{ctx}, ctx{ctx} { + ctx->loadAllAvailableDialects(); + if (sourceMgr.getNumBuffers() != 1) { + emitError(UnknownLoc::get(ctx), "One source file should be provided"); + return; + } + auto sourceBufId = sourceMgr.getMainFileID(); + auto source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer(); + srcName = StringAttr::get( + ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier()); + + auto parser = ParserHead{source, srcName}; + auto const wasmHeader = StringRef{"\0asm", 4}; + auto magicLoc = parser.getLocation(); + auto magic = parser.consumeNBytes(wasmHeader.size()); + if (failed(magic) || magic->compare(wasmHeader)) { + emitError(magicLoc, + "Source file does not contain valid Wasm header."); + return; + } + auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; + auto versionLoc = parser.getLocation(); + auto version = parser.consumeNBytes(expectedVersionString.size()); + if (failed(version)) + return; + if (version->compare(expectedVersionString)) { + emitError(versionLoc, + "Unsupported Wasm version. Only version 1 is supported."); + return; + } + auto fillRegistry = registry.populateFromBody(parser.copy()); + if (failed(fillRegistry)) + return; + + mOp = builder.create(getLocation()); + builder.setInsertionPointToStart( + &mOp.getBodyRegion().front()); + auto parsingTypes = parseSection(); + if (failed(parsingTypes)) + return; + + auto parsingImports = parseSection(); + if (failed(parsingImports)) + return; + + firstInternalFuncID = symbols.funcSymbols.size(); + + auto parsingFunctions = parseSection(); + if (failed(parsingFunctions)) + return; + + + // Copy over sizes of containers into statistics. + numFunctionSectionItems = symbols.funcSymbols.size(); + numGlobalSectionItems = symbols.globalSymbols.size(); + numMemorySectionItems = symbols.memSymbols.size(); + numTableSectionItems = symbols.tableSymbols.size(); + } + + ModuleOp getModule() { return mOp; } + +private: + mlir::StringAttr srcName; + OpBuilder builder; + WasmModuleSymbolTables symbols; + MLIRContext *ctx; + ModuleOp mOp; + SectionRegistry registry; + size_t firstInternalFuncID{0}; +}; + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { + auto importLoc = ph.getLocation(); + auto moduleName = ph.parseName(); + if (failed(moduleName)) + return failure(); + + auto importName = ph.parseName(); + if (failed(importName)) + return failure(); + + auto import = ph.parseImportDesc(ctx); + if (failed(import)) + return failure(); + + return std::visit( + [this, importLoc, &moduleName, &importName](auto import) { + return visitImport(importLoc, *moduleName, *importName, import); + }, + *import); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + auto opLoc = ph.getLocation(); + auto typeIdxParsed = ph.parseLiteral(); + if (failed(typeIdxParsed)) + return failure(); + auto typeIdx = *typeIdxParsed; + if (typeIdx >= symbols.moduleFuncTypes.size()) + return emitError(getLocation(), "Invalid type index: ") << typeIdx; + auto symbol = symbols.getNewFuncSymbolName(); + auto funcOp = + builder.create(opLoc, symbol, symbols.moduleFuncTypes[typeIdx]); + auto *block = funcOp.addEntryBlock(); + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToEnd(block); + builder.create(opLoc); + builder.restoreInsertionPoint(ip); + symbols.funcSymbols.push_back( + {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())}, + symbols.moduleFuncTypes[typeIdx]}); + return funcOp.verify(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + auto funcType = ph.parseFunctionType(ctx); + if (failed(funcType)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "Parsed function type " << *funcType << '\n'); + symbols.moduleFuncTypes.push_back(*funcType); + return success(); +} +} // namespace + +namespace mlir { +namespace wasm { +OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, + MLIRContext *context) { + WasmBinaryParser wBN{source, context}; + auto mOp = wBN.getModule(); + if (mOp) + return {mOp}; + + return {nullptr}; +} +} // namespace wasm +} // namespace mlir diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp new file mode 100644 index 0000000000000..9c0f7702a96aa --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp @@ -0,0 +1,28 @@ +//===- TranslateRegistration.cpp - Register translation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "mlir/Tools/mlir-translate/Translation.h" + + +using namespace mlir; + +namespace mlir { +void registerFromWasmTranslation() { + TranslateToMLIRRegistration registration{ + "import-wasm", "Translate WASM to MLIR", + [](llvm::SourceMgr &sourceMgr, MLIRContext* context) -> OwningOpRef { + return wasm::importWebAssemblyToModule(sourceMgr, context); + }, [](DialectRegistry& registry) { + registry.insert(); + } + }; +} +} // namespace mlir diff --git a/mlir/test/Target/Wasm/bad_wasm_version.yaml b/mlir/test/Target/Wasm/bad_wasm_version.yaml new file mode 100644 index 0000000000000..4fed1d5a3af3c --- /dev/null +++ b/mlir/test/Target/Wasm/bad_wasm_version.yaml @@ -0,0 +1,8 @@ +# RUN: yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s + +# CHECK: Unsupported Wasm version + +--- !WASM +FileHeader: + Version: 0xDEADBEEF +... diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir new file mode 100644 index 0000000000000..541dcf3a2d9eb --- /dev/null +++ b/mlir/test/Target/Wasm/import.mlir @@ -0,0 +1,19 @@ +// RUN: yaml2obj %S/inputs/import.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module +(import "my_module" "foo" (func $foo (param i32))) +(import "my_module" "bar" (func $bar (param i32))) +(import "my_module" "table" (table $round 2 funcref)) +(import "my_module" "mem" (memory $mymem 2)) +(import "my_module" "glob" (global $globglob i32)) +(import "my_other_module" "glob_mut" (global $glob_mut (mut i32))) +) +*/ + +// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()} +// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()} +// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa} +// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa, sym_visibility = "nested"} +// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32 +// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32 diff --git a/mlir/test/Target/Wasm/inputs/import.yaml.wasm b/mlir/test/Target/Wasm/inputs/import.yaml.wasm new file mode 100644 index 0000000000000..7c467ff6fbc67 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/import.yaml.wasm @@ -0,0 +1,44 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: [] + - Type: IMPORT + Imports: + - Module: my_module + Field: foo + Kind: FUNCTION + SigIndex: 0 + - Module: my_module + Field: bar + Kind: FUNCTION + SigIndex: 0 + - Module: my_module + Field: table + Kind: TABLE + Table: + Index: 0 + ElemType: FUNCREF + Limits: + Minimum: 0x2 + - Module: my_module + Field: mem + Kind: MEMORY + Memory: + Minimum: 0x2 + - Module: my_module + Field: glob + Kind: GLOBAL + GlobalType: I32 + GlobalMutable: false + - Module: my_other_module + Field: glob_mut + Kind: GLOBAL + GlobalType: I32 + GlobalMutable: true +... diff --git a/mlir/test/Target/Wasm/inputs/stats.yaml.wasm b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm new file mode 100644 index 0000000000000..bf577688b3aed --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm @@ -0,0 +1,38 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: TABLE + Tables: + - Index: 0 + ElemType: FUNCREF + Limits: + Minimum: 0x2 + - Type: MEMORY + Memories: + - Flags: [ HAS_MAX ] + Minimum: 0x0 + Maximum: 0x10000 + - Type: GLOBAL + Globals: + - Index: 0 + Type: I32 + Mutable: false + InitExpr: + Opcode: I32_CONST + Value: 10 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 20000B +... diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml new file mode 100644 index 0000000000000..961e9cc6e8029 --- /dev/null +++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml @@ -0,0 +1,18 @@ +# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s +# CHECK: error: Invalid type index: 2 + +# FIXME: mlir-translate should not return 0 here. + +--- !WASM +FileHeader: + Version: 0x00000001 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: [] + - Type: FUNCTION + FunctionTypes: + - 2 diff --git a/mlir/test/Target/Wasm/missing_header.yaml b/mlir/test/Target/Wasm/missing_header.yaml new file mode 100644 index 0000000000000..5610f9c5c6e33 --- /dev/null +++ b/mlir/test/Target/Wasm/missing_header.yaml @@ -0,0 +1,12 @@ +# RUN: not yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s + +# CHECK: Source file does not contain valid Wasm header + +--- !WASM +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: [] +... diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir new file mode 100644 index 0000000000000..e68b85d20f67d --- /dev/null +++ b/mlir/test/Target/Wasm/stats.mlir @@ -0,0 +1,19 @@ +// RUN: yaml2obj %S/inputs/stats.yaml.wasm -o - | mlir-translate --import-wasm -stats 2>&1 | FileCheck %s +// Check that we get the correct stats for a module that has a single +// function, table, memory, and global. +// REQUIRES: asserts + +/* Source code used to create this test: +(module + (type (;0;) (func (param i32) (result i32))) + (func (;0;) (type 0) (param i32) (result i32) + local.get 0) + (table (;0;) 2 funcref) + (memory (;0;) 0 65536) + (global (;0;) i32 (i32.const 10))) +*/ + +// CHECK: 1 wasm-translate - Parsed functions +// CHECK-NEXT: 0 wasm-translate - Parsed globals +// CHECK-NEXT: 0 wasm-translate - Parsed memories +// CHECK-NEXT: 0 wasm-translate - Parsed tables From b942297c12f895c2cbfcdd60d18828ddd0438018 Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Mon, 30 Jun 2025 19:11:09 +0200 Subject: [PATCH 02/14] [mlir][wasm] Handling table in Wasm importer --------- Co-authored-by: Ferdinand Lemaire Co-authored-by: Jessica Paquette --- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 18 +++++++++++++++ mlir/test/Target/Wasm/inputs/table.yaml.wasm | 23 ++++++++++++++++++++ mlir/test/Target/Wasm/stats.mlir | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Target/Wasm/inputs/table.yaml.wasm diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 2962cc212f848..23f0cca1a148f 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -701,6 +701,9 @@ class WasmBinaryParser { if (failed(parsingFunctions)) return; + auto parsingTables = parseSection(); + if (failed(parsingTables)) + return; // Copy over sizes of containers into statistics. numFunctionSectionItems = symbols.funcSymbols.size(); @@ -744,6 +747,21 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, size *import); } +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { + auto opLocation = ph.getLocation(); + auto tableType = ph.parseTableType(ctx); + if (failed(tableType)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *tableType + << '\n'); + auto symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); + auto tableOp = builder.create(opLocation, symbol.strref(), *tableType); + symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)}); + return success(); +} + template <> LogicalResult WasmBinaryParser::parseSectionItem(ParserHead &ph, diff --git a/mlir/test/Target/Wasm/inputs/table.yaml.wasm b/mlir/test/Target/Wasm/inputs/table.yaml.wasm new file mode 100644 index 0000000000000..387f41820524f --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/table.yaml.wasm @@ -0,0 +1,23 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TABLE + Tables: + - Index: 0 + ElemType: FUNCREF + Limits: + Minimum: 0x2 + - Index: 1 + ElemType: FUNCREF + Limits: + Flags: [ HAS_MAX ] + Minimum: 0x2 + Maximum: 0x4 + - Index: 2 + ElemType: EXTERNREF + Limits: + Flags: [ HAS_MAX ] + Minimum: 0x2 + Maximum: 0x4 +... diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir index e68b85d20f67d..dc30e95343d5a 100644 --- a/mlir/test/Target/Wasm/stats.mlir +++ b/mlir/test/Target/Wasm/stats.mlir @@ -16,4 +16,4 @@ // CHECK: 1 wasm-translate - Parsed functions // CHECK-NEXT: 0 wasm-translate - Parsed globals // CHECK-NEXT: 0 wasm-translate - Parsed memories -// CHECK-NEXT: 0 wasm-translate - Parsed tables +// CHECK-NEXT: 1 wasm-translate - Parsed tables From 9af8a9d48496e7d41f10136de16706774392f8b4 Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Mon, 30 Jun 2025 19:27:44 +0200 Subject: [PATCH 03/14] [mlir][wasm] Handle memory in Wasm importer --------- Co-authored-by: Ferdinand Lemaire Co-authored-by: Jessica Paquette --- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 20 +++++++++++++++++++ .../Wasm/inputs/memory_min_eq_max.yaml.wasm | 10 ++++++++++ .../Wasm/inputs/memory_min_max.yaml.wasm | 10 ++++++++++ .../Wasm/inputs/memory_min_no_max.yaml.wasm | 8 ++++++++ mlir/test/Target/Wasm/memory_min_eq_max.mlir | 7 +++++++ mlir/test/Target/Wasm/memory_min_max.mlir | 7 +++++++ mlir/test/Target/Wasm/memory_min_no_max.mlir | 7 +++++++ mlir/test/Target/Wasm/stats.mlir | 2 +- 8 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm create mode 100644 mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm create mode 100644 mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm create mode 100644 mlir/test/Target/Wasm/memory_min_eq_max.mlir create mode 100644 mlir/test/Target/Wasm/memory_min_max.mlir create mode 100644 mlir/test/Target/Wasm/memory_min_no_max.mlir diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 23f0cca1a148f..fe58b43d5d24d 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -705,6 +705,11 @@ class WasmBinaryParser { if (failed(parsingTables)) return; + auto parsingMems = parseSection(); + if (failed(parsingMems)) + return; + + // Copy over sizes of containers into statistics. numFunctionSectionItems = symbols.funcSymbols.size(); numGlobalSectionItems = symbols.globalSymbols.size(); @@ -798,6 +803,21 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, symbols.moduleFuncTypes.push_back(*funcType); return success(); } + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { + auto opLocation = ph.getLocation(); + auto memory = ph.parseLimit(ctx); + if (failed(memory)) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << " Registering memory " << *memory << '\n'); + auto symbol = symbols.getNewMemorySymbolName(); + auto memOp = builder.create(opLocation, symbol, *memory); + symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)}); + return success(); +} } // namespace namespace mlir { diff --git a/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm new file mode 100644 index 0000000000000..f3edf5f2d0cc2 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm @@ -0,0 +1,10 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: MEMORY + Memories: + - Flags: [ HAS_MAX ] + Minimum: 0x0 + Maximum: 0x0 +... diff --git a/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm new file mode 100644 index 0000000000000..fe70fb686df37 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm @@ -0,0 +1,10 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: MEMORY + Memories: + - Flags: [ HAS_MAX ] + Minimum: 0x0 + Maximum: 0x10000 +... diff --git a/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm new file mode 100644 index 0000000000000..8508ce38251a3 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm @@ -0,0 +1,8 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: MEMORY + Memories: + - Minimum: 0x1 +... diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir new file mode 100644 index 0000000000000..088e28685d09a --- /dev/null +++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir @@ -0,0 +1,7 @@ +// RUN: yaml2obj %S/inputs/memory_min_eq_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module (memory 0 0)) +*/ + +// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa, sym_name = "mem_0", sym_visibility = "nested"}> : () -> () diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir new file mode 100644 index 0000000000000..16d3468279d42 --- /dev/null +++ b/mlir/test/Target/Wasm/memory_min_max.mlir @@ -0,0 +1,7 @@ +// RUN: yaml2obj %S/inputs/memory_min_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module (memory 0 65536)) +*/ + +// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa, sym_name = "mem_0", sym_visibility = "nested"}> : () -> () diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir new file mode 100644 index 0000000000000..f71cb1098be18 --- /dev/null +++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir @@ -0,0 +1,7 @@ +// RUN: yaml2obj %S/inputs/memory_min_no_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module (memory 1)) +*/ + +// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa, sym_name = "mem_0", sym_visibility = "nested"}> : () -> () diff --git a/mlir/test/Target/Wasm/stats.mlir b/mlir/test/Target/Wasm/stats.mlir index dc30e95343d5a..b361de3d99f31 100644 --- a/mlir/test/Target/Wasm/stats.mlir +++ b/mlir/test/Target/Wasm/stats.mlir @@ -15,5 +15,5 @@ // CHECK: 1 wasm-translate - Parsed functions // CHECK-NEXT: 0 wasm-translate - Parsed globals -// CHECK-NEXT: 0 wasm-translate - Parsed memories +// CHECK-NEXT: 1 wasm-translate - Parsed memories // CHECK-NEXT: 1 wasm-translate - Parsed tables From 35655b8ea01f068160b206f5a916dbf9406e7b40 Mon Sep 17 00:00:00 2001 From: Ferdinand Lemaire Date: Tue, 1 Jul 2025 13:42:45 +0900 Subject: [PATCH 04/14] [mlir][wasm] Handling of export at Wasm importer level -- Co-authored-by: Luc Forget Co-authored-by: Jessica Paquette --- .../mlir/Target/Wasm/WasmBinaryEncoding.h | 9 +++ mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 74 +++++++++++++++++++ .../Wasm/function_export_out_of_scope.yaml | 15 ++++ 3 files changed, 98 insertions(+) create mode 100644 mlir/test/Target/Wasm/function_export_out_of_scope.yaml diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h index e01193e47fdea..f4721d943fe81 100644 --- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -49,6 +49,15 @@ struct WasmBinaryEncoding { static constexpr std::byte isMutable{0x01}; }; + /// Byte encodings describing WASM exports. + struct Export { + static constexpr std::byte function{0x00}; + static constexpr std::byte table{0x01}; + static constexpr std::byte memory{0x02}; + static constexpr std::byte global{0x03}; + }; + + static constexpr std::byte endByte{0x0B}; }; } // namespace mlir diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index fe58b43d5d24d..1b0235e7d6f90 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/Wasm/WasmImporter.h" #include "llvm/ADT/Statistic.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/LEB128.h" #include @@ -709,6 +710,9 @@ class WasmBinaryParser { if (failed(parsingMems)) return; + auto parsingExports = parseSection(); + if (failed(parsingExports)) + return; // Copy over sizes of containers into statistics. numFunctionSectionItems = symbols.funcSymbols.size(); @@ -752,6 +756,76 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, size *import); } +template <> +LogicalResult +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + auto exportLoc = ph.getLocation(); + + auto exportName = ph.parseName(); + if (failed(exportName)) + return failure(); + + auto opcode = ph.consumeByte(); + if (failed(opcode)) + return failure(); + + auto idx = ph.parseLiteral(); + if (failed(idx)) + return failure(); + + using SymbolRefDesc = + std::variant, + llvm::SmallVector, + llvm::SmallVector>; + + SymbolRefDesc currentSymbolList; + std::string symbolType = ""; + switch (*opcode) { + case WasmBinaryEncoding::Export::function: + symbolType = "function"; + currentSymbolList = symbols.funcSymbols; + break; + case WasmBinaryEncoding::Export::table: + symbolType = "table"; + currentSymbolList = symbols.tableSymbols; + break; + case WasmBinaryEncoding::Export::memory: + symbolType = "memory"; + currentSymbolList = symbols.memSymbols; + break; + case WasmBinaryEncoding::Export::global: + symbolType = "global"; + currentSymbolList = symbols.globalSymbols; + break; + default: + return emitError(exportLoc, "Invalid value for export type: ") + << std::to_integer(*opcode); + } + + auto currentSymbol = std::visit( + [&](const auto &list) -> FailureOr { + if (*idx > list.size()) { + emitError( + exportLoc, + llvm::formatv( + "Trying to export {0} {1} which is undefined in this scope", + symbolType, *idx)); + return failure(); + } + return list[*idx].symbol; + }, + currentSymbolList); + + if (failed(currentSymbol)) + return failure(); + + Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); + SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); + auto symName = SymbolTable::getSymbolName(op); + return SymbolTable{mOp}.rename(symName, *exportName); +} + template <> LogicalResult WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml new file mode 100644 index 0000000000000..ffb26f563141a --- /dev/null +++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml @@ -0,0 +1,15 @@ +# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s + +# FIXME: The error code here should be nonzero. + +# CHECK: Trying to export function 42 which is undefined in this scope + +--- !WASM +FileHeader: + Version: 0x00000001 +Sections: + - Type: EXPORT + Exports: + - Name: function_export + Kind: FUNCTION + Index: 42 From 1ededd4183e903675aebf09fca3405f7c216e213 Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Thu, 3 Jul 2025 09:49:59 +0900 Subject: [PATCH 05/14] [mlir][wasm] Expression parsing mechanism for Wasm importer --------- Co-authored-by: Ferdinand Lemaire Co-authored-by: Jessica Paquette --- .../mlir/Target/Wasm/WasmBinaryEncoding.h | 9 + mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 333 ++++++++++++++++++ 2 files changed, 342 insertions(+) diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h index f4721d943fe81..a5b124eecbe67 100644 --- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -16,6 +16,15 @@ #include namespace mlir { struct WasmBinaryEncoding { + /// Byte encodings for WASM instructions. + struct OpCode { + // Locals, globals, constants. + static constexpr std::byte constI32{0x41}; + static constexpr std::byte constI64{0x42}; + static constexpr std::byte constFP32{0x43}; + static constexpr std::byte constFP64{0x44}; + }; + /// Byte encodings of types in WASM binaries struct Type { static constexpr std::byte emptyBlockType{0x40}; diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 1b0235e7d6f90..753513eb9d887 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -142,6 +142,8 @@ struct FunctionSymbolRefContainer : SymbolRefContainer { using ImportDesc = std::variant; +using parsed_inst_t = llvm::FailureOr>; + struct WasmModuleSymbolTables { llvm::SmallVector funcSymbols; llvm::SmallVector globalSymbols; @@ -173,6 +175,134 @@ struct WasmModuleSymbolTables { return getNewSymbolName("table_", id); } }; + +class ParserHead; + +/// Wrapper around SmallVector to only allow access as push and pop on the +/// stack. Makes sure that there are no "free accesses" on the stack to preserve +/// its state. +class ValueStack { +private: + struct LabelLevel { + size_t stackIdx; + LabelLevelOpInterface levelOp; + }; +public: + bool empty() const { return values.empty(); } + + size_t size() const { return values.size(); } + + /// Pops values from the stack because they are being used in an operation. + /// @param operandTypes The list of expected types of the operation, used + /// to know how many values to pop and check if the types match the + /// expectation. + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + /// @return Failure or the vector of popped values. + llvm::FailureOr> popOperands(TypeRange operandTypes, + Location *opLoc); + + /// Push the results of an operation to the stack so they can be used in a + /// following operation. + /// @param results The list of results of the operation + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + LogicalResult pushResults(ValueRange results, Location *opLoc); + + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// A simple dump function for debugging. + /// Writes output to llvm::dbgs(). + LLVM_DUMP_METHOD void dump() const; +#endif + +private: + llvm::SmallVector values; +}; + +using local_val_t = TypedValue; + +class ExpressionParser { +public: + using locals_t = llvm::SmallVector; + ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols, + llvm::ArrayRef initLocal) + : parser{parser}, symbols{symbols}, locals{initLocal} {} + +private: + template + inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder); + + template + parsed_inst_t + parseConstInst(OpBuilder &builder, + std::enable_if_t> * = nullptr); + + + /// This function generates a dispatch tree to associate an opcode with a + /// parser. Parsers are registered by specialising the + /// `parseSpecificInstruction` function for the op code to handle. + /// + /// The dispatcher is generated by recursively creating all possible patterns + /// for an opcode and calling the relevant parser on the leaf. + /// + /// @tparam patternBitSize is the first bit for which the pattern is not fixed + /// + /// @tparam highBitPattern is the fixed pattern that this instance handles for + /// the 8-patternBitSize bits + template + inline parsed_inst_t dispatchToInstParser(std::byte opCode, + OpBuilder &builder) { + static_assert(patternBitSize <= 8, + "PatternBitSize is outside of range of opcode space! " + "(expected at most 8 bits)"); + if constexpr (patternBitSize < 8) { + constexpr std::byte bitSelect{1 << (7 - patternBitSize)}; + constexpr std::byte nextHighBitPatternStem = highBitPattern << 1; + constexpr size_t nextPatternBitSize = patternBitSize + 1; + if ((opCode & bitSelect) != std::byte{0}) + return dispatchToInstParser < nextPatternBitSize, + nextHighBitPatternStem | std::byte{1} > (opCode, builder); + return dispatchToInstParser( + opCode, builder); + } else { + return parseSpecificInstruction(builder); + } + } + + struct ParseResultWithInfo { + llvm::SmallVector opResults; + std::byte endingByte; + }; + +public: + template + parsed_inst_t parse(OpBuilder &builder, + UniqueByte = {}); + + template + llvm::FailureOr + parse(OpBuilder &builder, + ByteSequence parsingEndFilters); + + llvm::FailureOr> + popOperands(TypeRange operandTypes) { + return valueStack.popOperands(operandTypes, ¤tOpLoc.value()); + } + + LogicalResult pushResults(ValueRange results) { + return valueStack.pushResults(results, ¤tOpLoc.value()); + } +private: + std::optional currentOpLoc; + ParserHead &parser; + WasmModuleSymbolTables const &symbols; + locals_t locals; + ValueStack valueStack; + }; + class ParserHead { public: ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {} @@ -382,6 +512,14 @@ class ParserHead { << static_cast(*importType); } } + + parsed_inst_t parseExpression(OpBuilder &builder, + WasmModuleSymbolTables const &symbols, + llvm::ArrayRef locals = {}) { + auto eParser = ExpressionParser{*this, symbols, locals}; + return eParser.parse(builder); + } + bool end() const { return curHead().empty(); } ParserHead copy() const { @@ -491,6 +629,201 @@ inline llvm::FailureOr ParserHead::parseI64() { return parseLiteral(); } +template +inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) { + return emitError(*currentOpLoc, "Unknown instruction opcode: ") + << static_cast(opCode); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void ValueStack::dump() const { + llvm::dbgs() << "================= Wasm ValueStack =======================\n"; + llvm::dbgs() << "size: " << size() << "\n"; + llvm::dbgs() << "" + << "\n"; + // Stack is pushed to via push_back. Therefore the top of the stack is the + // end of the vector. Iterate in reverse so that the first thing we print + // is the top of the stack. + size_t stackSize = size(); + for (size_t idx = 0 ; idx < stackSize ;) { + size_t actualIdx = stackSize - 1 - idx; + llvm::dbgs() << " "; + values[actualIdx].dump(); + } + llvm::dbgs() << "" + << "\n"; + llvm::dbgs() << "=========================================================\n"; +} +#endif + +parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { + LLVM_DEBUG(llvm::dbgs() << "Popping from ValueStack\n"); + LLVM_DEBUG(llvm::dbgs() << " Elements(s) to pop: " << operandTypes.size() + << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n"); + if (operandTypes.size() > values.size()) + return emitError(*opLoc, + "Stack doesn't contain enough values. Trying to get ") + << operandTypes.size() << " operands on a stack containing only " + << values.size() << " values."; + size_t stackIdxOffset = values.size() - operandTypes.size(); + llvm::SmallVector res{}; + res.reserve(operandTypes.size()); + for (size_t i{0}; i < operandTypes.size(); ++i) { + Value operand = values[i + stackIdxOffset]; + Type stackType = operand.getType(); + if (stackType != operandTypes[i]) + return emitError(*opLoc, + "Invalid operand type on stack. Expecting ") + << operandTypes[i] << ", value on stack is of type " << stackType + << "."; + LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n"); + res.push_back(operand); + } + values.resize(values.size() - operandTypes.size()); + LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n"); + return res; +} + +LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { + LLVM_DEBUG(llvm::dbgs() << "Pushing to ValueStack\n"); + LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size() + << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n"); + for (auto val : results) { + if (!isWasmValueType(val.getType())) + return emitError(*opLoc, "Invalid value type on stack: ") + << val.getType(); + LLVM_DEBUG(llvm::dbgs() << " PUSH: " << val << "\n"); + values.push_back(val); + } + + LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n"); + return success(); +} + +template +parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte endByte) { + auto res = parse(builder, ByteSequence{}); + if (failed(res)) + return failure(); + return res->opResults; +} + +template +llvm::FailureOr +ExpressionParser::parse(OpBuilder &builder, + ByteSequence parsingEndFilters) { + llvm::SmallVector res; + for (;;) { + currentOpLoc = parser.getLocation(); + auto opCode = parser.consumeByte(); + if (failed(opCode)) + return failure(); + if (isValueOneOf(*opCode, parsingEndFilters)) + return {{res, *opCode}}; + parsed_inst_t resParsed; + resParsed = dispatchToInstParser(*opCode, builder); + if (failed(resParsed)) + return failure(); + std::swap(res, *resParsed); + if (failed(pushResults(res))) + return failure(); + } +} + + +template +inline Type buildLiteralType(OpBuilder &); + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getF32Type(); +} + +template <> +inline Type buildLiteralType(OpBuilder &builder) { + return builder.getF64Type(); +} + +template>> +struct AttrHolder; + +template +struct AttrHolder>> { + using type = IntegerAttr; +}; + +template +struct AttrHolder>> { + using type = FloatAttr; +}; + +template +using attr_holder_t = typename AttrHolder::type; + +template >> +attr_holder_t buildLiteralAttr(OpBuilder &builder, ValT val) { + return attr_holder_t::get(buildLiteralType(builder), val); +} + +template +parsed_inst_t ExpressionParser::parseConstInst( + OpBuilder &builder, std::enable_if_t> *) { + auto parsedConstant = parser.parseLiteral(); + if (failed(parsedConstant)) + return failure(); + auto constOp = builder.create( + *currentOpLoc, buildLiteralAttr(builder, *parsedConstant)); + return {{constOp.getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) { + return parseConstInst(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) { + return parseConstInst(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) { + return parseConstInst(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) { + return parseConstInst(builder); +} + + class WasmBinaryParser { private: struct SectionRegistry { From 2cb8732c4a731d3ddee528bf15e383c810f982ed Mon Sep 17 00:00:00 2001 From: Ferdinand Lemaire Date: Thu, 31 Jul 2025 13:45:21 +0900 Subject: [PATCH 06/14] Formatting and other comments from the previous PR --- mlir/include/mlir/InitAllTranslations.h | 1 + mlir/include/mlir/Target/Wasm/WasmImporter.h | 9 +- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 428 +++++++++--------- .../lib/Target/Wasm/TranslateRegistration.cpp | 16 +- mlir/test/Target/Wasm/bad_wasm_version.yaml | 2 +- .../Wasm/function_export_out_of_scope.yaml | 2 +- .../Wasm/invalid_function_type_index.yaml | 2 +- mlir/test/Target/Wasm/missing_header.yaml | 2 +- 8 files changed, 232 insertions(+), 230 deletions(-) diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h index cf8f108b88159..622024db5a8a2 100644 --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -17,6 +17,7 @@ #include "mlir/Target/IRDLToCpp/TranslationRegistration.h" namespace mlir { + void registerFromLLVMIRTranslation(); void registerFromSPIRVTranslation(); void registerFromWasmTranslation(); diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h index fc7d275353964..5cc42a1f32fa4 100644 --- a/mlir/include/mlir/Target/Wasm/WasmImporter.h +++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h @@ -19,8 +19,7 @@ #include "mlir/IR/OwningOpRef.h" #include "llvm/Support/SourceMgr.h" -namespace mlir { -namespace wasm { +namespace mlir::wasm { /// Translates the given operation to C++ code. The operation or operations in /// the region of 'op' need almost all be in EmitC dialect. The parameter @@ -28,8 +27,8 @@ namespace wasm { /// arguments are declared at the beginning of the function. /// If parameter 'fileId' is non-empty, then body of `emitc.file` ops /// with matching id are emitted. -OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext* context); -} // namespace wasm -} // namespace mlir +OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, + MLIRContext *context); +} // namespace mlir::wasm #endif // MLIR_TARGET_WASM_WASMIMPORTER_H diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 753513eb9d887..dd8b86670c31a 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -9,6 +9,8 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" #include "mlir/Target/Wasm/WasmBinaryEncoding.h" #include "mlir/Target/Wasm/WasmImporter.h" #include "llvm/ADT/Statistic.h" @@ -16,6 +18,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/LEB128.h" +#include #include #define DEBUG_TYPE "wasm-translate" @@ -26,7 +29,8 @@ STATISTIC(numGlobalSectionItems, "Parsed globals"); STATISTIC(numMemorySectionItems, "Parsed memories"); STATISTIC(numTableSectionItems, "Parsed tables"); -static_assert(CHAR_BIT == 8, "This code expects std::byte to be exactly 8 bits"); +static_assert(CHAR_BIT == 8, + "This code expects std::byte to be exactly 8 bits"); using namespace mlir; using namespace mlir::wasm; @@ -51,7 +55,7 @@ enum struct WasmSectionType : section_id_t { }; constexpr section_id_t highestWasmSectionID{ - static_cast(WasmSectionType::DATACOUNT)}; + static_cast(WasmSectionType::DATACOUNT)}; #define APPLY_WASM_SEC_TRANSFORM \ WASM_SEC_TRANSFORM(CUSTOM) \ @@ -82,7 +86,7 @@ constexpr bool sectionShouldBeUnique(WasmSectionType secType) { } template -struct ByteSequence{}; +struct ByteSequence {}; template constexpr ByteSequence @@ -91,7 +95,7 @@ operator+(ByteSequence, ByteSequence) { } /// Template class for representing a byte sequence of only one byte -template +template struct UniqueByte : ByteSequence {}; template @@ -109,12 +113,13 @@ constexpr ByteSequence< WasmBinaryEncoding::Type::v128> valueTypesEncodings{}; -template -constexpr bool isValueOneOf(std::byte value, ByteSequence = {}) { - return ((value == allowedFlags) | ... | false); +template +constexpr bool isValueOneOf(std::byte value, + ByteSequence = {}) { + return ((value == allowedFlags) | ... | false); } -template +template constexpr bool isNotIn(std::byte value, ByteSequence = {}) { return !isValueOneOf(value); } @@ -140,19 +145,20 @@ struct FunctionSymbolRefContainer : SymbolRefContainer { FunctionType functionType; }; -using ImportDesc = std::variant; +using ImportDesc = + std::variant; -using parsed_inst_t = llvm::FailureOr>; +using parsed_inst_t = FailureOr>; struct WasmModuleSymbolTables { - llvm::SmallVector funcSymbols; - llvm::SmallVector globalSymbols; - llvm::SmallVector memSymbols; - llvm::SmallVector tableSymbols; - llvm::SmallVector moduleFuncTypes; - - std::string getNewSymbolName(llvm::StringRef prefix, size_t id) const { - return (prefix + llvm::Twine{id}).str(); + SmallVector funcSymbols; + SmallVector globalSymbols; + SmallVector memSymbols; + SmallVector tableSymbols; + SmallVector moduleFuncTypes; + + std::string getNewSymbolName(StringRef prefix, size_t id) const { + return (prefix + Twine{id}).str(); } std::string getNewFuncSymbolName() const { @@ -187,6 +193,7 @@ class ValueStack { size_t stackIdx; LabelLevelOpInterface levelOp; }; + public: bool empty() const { return values.empty(); } @@ -200,8 +207,8 @@ class ValueStack { /// location /// if an error occurs. /// @return Failure or the vector of popped values. - llvm::FailureOr> popOperands(TypeRange operandTypes, - Location *opLoc); + FailureOr> popOperands(TypeRange operandTypes, + Location *opLoc); /// Push the results of an operation to the stack so they can be used in a /// following operation. @@ -211,7 +218,6 @@ class ValueStack { /// if an error occurs. LogicalResult pushResults(ValueRange results, Location *opLoc); - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// A simple dump function for debugging. /// Writes output to llvm::dbgs(). @@ -219,16 +225,16 @@ class ValueStack { #endif private: - llvm::SmallVector values; + SmallVector values; }; using local_val_t = TypedValue; class ExpressionParser { public: - using locals_t = llvm::SmallVector; + using locals_t = SmallVector; ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols, - llvm::ArrayRef initLocal) + ArrayRef initLocal) : parser{parser}, symbols{symbols}, locals{initLocal} {} private: @@ -240,7 +246,6 @@ class ExpressionParser { parseConstInst(OpBuilder &builder, std::enable_if_t> * = nullptr); - /// This function generates a dispatch tree to associate an opcode with a /// parser. Parsers are registered by specialising the /// `parseSpecificInstruction` function for the op code to handle. @@ -263,8 +268,9 @@ class ExpressionParser { constexpr std::byte nextHighBitPatternStem = highBitPattern << 1; constexpr size_t nextPatternBitSize = patternBitSize + 1; if ((opCode & bitSelect) != std::byte{0}) - return dispatchToInstParser < nextPatternBitSize, - nextHighBitPatternStem | std::byte{1} > (opCode, builder); + return dispatchToInstParser( + opCode, builder); return dispatchToInstParser( opCode, builder); } else { @@ -273,40 +279,40 @@ class ExpressionParser { } struct ParseResultWithInfo { - llvm::SmallVector opResults; + SmallVector opResults; std::byte endingByte; }; public: - template - parsed_inst_t parse(OpBuilder &builder, - UniqueByte = {}); + template + parsed_inst_t parse(OpBuilder &builder, UniqueByte = {}); template - llvm::FailureOr + FailureOr parse(OpBuilder &builder, ByteSequence parsingEndFilters); - llvm::FailureOr> - popOperands(TypeRange operandTypes) { + FailureOr> popOperands(TypeRange operandTypes) { return valueStack.popOperands(operandTypes, ¤tOpLoc.value()); } LogicalResult pushResults(ValueRange results) { return valueStack.pushResults(results, ¤tOpLoc.value()); } + private: std::optional currentOpLoc; ParserHead &parser; WasmModuleSymbolTables const &symbols; locals_t locals; ValueStack valueStack; - }; +}; class ParserHead { public: - ParserHead(llvm::StringRef src, StringAttr name) : head{src}, locName{name} {} + ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {} ParserHead(ParserHead &&) = default; + private: ParserHead(ParserHead const &other) = default; @@ -315,7 +321,7 @@ class ParserHead { return FileLineColLoc::get(locName, 0, anchorOffset + offset); } - llvm::FailureOr consumeNBytes(size_t nBytes) { + FailureOr consumeNBytes(size_t nBytes) { LLVM_DEBUG(llvm::dbgs() << "Consume " << nBytes << " bytes\n"); LLVM_DEBUG(llvm::dbgs() << " Bytes remaining: " << size() << "\n"); LLVM_DEBUG(llvm::dbgs() << " Current offset: " << offset << "\n"); @@ -323,14 +329,14 @@ class ParserHead { return emitError(getLocation(), "trying to extract ") << nBytes << "bytes when only " << size() << "are avilables"; - auto res = head.slice(offset, offset + nBytes); + StringRef res = head.slice(offset, offset + nBytes); offset += nBytes; LLVM_DEBUG(llvm::dbgs() << " Updated offset (+" << nBytes << "): " << offset << "\n"); return res; } - llvm::FailureOr consumeByte() { + FailureOr consumeByte() { auto res = consumeNBytes(1); if (failed(res)) return failure(); @@ -338,52 +344,52 @@ class ParserHead { } template - llvm::FailureOr parseLiteral(); + FailureOr parseLiteral(); - llvm::FailureOr parseVectorSize(); + FailureOr parseVectorSize(); private: // TODO: This is equivalent to parseLiteral and could be removed // if parseLiteral specialization were moved here, but default GCC on Ubuntu // 22.04 has bug with template specialization in class declaration - inline llvm::FailureOr parseUI32(); - inline llvm::FailureOr parseI64(); + inline FailureOr parseUI32(); + inline FailureOr parseI64(); public: - llvm::FailureOr parseName() { - auto size = parseVectorSize(); + FailureOr parseName() { + FailureOr size = parseVectorSize(); if (failed(size)) return failure(); return consumeNBytes(*size); } - llvm::FailureOr parseWasmSectionType() { - auto id = consumeByte(); + FailureOr parseWasmSectionType() { + FailureOr id = consumeByte(); if (failed(id)) return failure(); if (std::to_integer(*id) > highestWasmSectionID) - return emitError(getLocation(), "Invalid section ID: ") + return emitError(getLocation(), "invalid section ID: ") << static_cast(*id); return static_cast(*id); } - llvm::FailureOr parseLimit(MLIRContext *ctx) { + FailureOr parseLimit(MLIRContext *ctx) { using WasmLimits = WasmBinaryEncoding::LimitHeader; - auto limitLocation = getLocation(); - auto limitHeader = consumeByte(); + FileLineColLoc limitLocation = getLocation(); + FailureOr limitHeader = consumeByte(); if (failed(limitHeader)) return failure(); if (isNotIn(*limitHeader)) - return emitError(limitLocation, "Invalid limit header: ") + return emitError(limitLocation, "invalid limit header: ") << static_cast(*limitHeader); - auto minParse = parseUI32(); + FailureOr minParse = parseUI32(); if (failed(minParse)) return failure(); std::optional max{std::nullopt}; if (*limitHeader == WasmLimits::bothLimits) { - auto maxParse = parseUI32(); + FailureOr maxParse = parseUI32(); if (failed(maxParse)) return failure(); max = *maxParse; @@ -391,9 +397,9 @@ class ParserHead { return LimitType::get(ctx, *minParse, max); } - llvm::FailureOr parseValueType(MLIRContext *ctx) { - auto typeLoc = getLocation(); - auto typeEncoding = consumeByte(); + FailureOr parseValueType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr typeEncoding = consumeByte(); if (failed(typeEncoding)) return failure(); switch (*typeEncoding) { @@ -412,35 +418,35 @@ class ParserHead { case WasmBinaryEncoding::Type::externRef: return wasmssa::ExternRefType::get(ctx); default: - return emitError(typeLoc, "Invalid value type encoding: ") + return emitError(typeLoc, "invalid value type encoding: ") << static_cast(*typeEncoding); } } - llvm::FailureOr parseGlobalType(MLIRContext *ctx) { + FailureOr parseGlobalType(MLIRContext *ctx) { using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability; - auto typeParsed = parseValueType(ctx); + FailureOr typeParsed = parseValueType(ctx); if (failed(typeParsed)) return failure(); - auto mutLoc = getLocation(); - auto mutSpec = consumeByte(); + FileLineColLoc mutLoc = getLocation(); + FailureOr mutSpec = consumeByte(); if (failed(mutSpec)) return failure(); if (isNotIn(*mutSpec)) - return emitError(mutLoc, "Invalid global mutability specifier: ") + return emitError(mutLoc, "invalid global mutability specifier: ") << static_cast(*mutSpec); return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable}; } - llvm::FailureOr parseResultType(MLIRContext *ctx) { - auto nParamsParsed = parseVectorSize(); + FailureOr parseResultType(MLIRContext *ctx) { + FailureOr nParamsParsed = parseVectorSize(); if (failed(nParamsParsed)) return failure(); - auto nParams = *nParamsParsed; - llvm::SmallVector res{}; + uint32_t nParams = *nParamsParsed; + SmallVector res{}; res.reserve(nParams); for (size_t i = 0; i < nParams; ++i) { - auto parsedType = parseValueType(ctx); + FailureOr parsedType = parseValueType(ctx); if (failed(parsedType)) return failure(); res.push_back(*parsedType); @@ -448,50 +454,49 @@ class ParserHead { return TupleType::get(ctx, res); } - llvm::FailureOr parseFunctionType(MLIRContext *ctx) { - auto typeLoc = getLocation(); - auto funcTypeHeader = consumeByte(); + FailureOr parseFunctionType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr funcTypeHeader = consumeByte(); if (failed(funcTypeHeader)) return failure(); if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType) - return emitError(typeLoc, "Invalid function type header byte. Expecting ") - << std::to_integer( - WasmBinaryEncoding::Type::funcType) + return emitError(typeLoc, "invalid function type header byte. Expecting ") + << std::to_integer(WasmBinaryEncoding::Type::funcType) << " got " << std::to_integer(*funcTypeHeader); - auto inputTypes = parseResultType(ctx); + FailureOr inputTypes = parseResultType(ctx); if (failed(inputTypes)) return failure(); - auto resTypes = parseResultType(ctx); + FailureOr resTypes = parseResultType(ctx); if (failed(resTypes)) return failure(); return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes()); } - llvm::FailureOr parseTypeIndex() { - auto res = parseUI32(); + FailureOr parseTypeIndex() { + FailureOr res = parseUI32(); if (failed(res)) return failure(); return TypeIdxRecord{*res}; } - llvm::FailureOr parseTableType(MLIRContext *ctx) { - auto elmTypeParse = parseValueType(ctx); + FailureOr parseTableType(MLIRContext *ctx) { + FailureOr elmTypeParse = parseValueType(ctx); if (failed(elmTypeParse)) return failure(); if (!isWasmRefType(*elmTypeParse)) - return emitError(getLocation(), "Invalid element type for table"); - auto limitParse = parseLimit(ctx); + return emitError(getLocation(), "invalid element type for table"); + FailureOr limitParse = parseLimit(ctx); if (failed(limitParse)) return failure(); return TableType::get(ctx, *elmTypeParse, *limitParse); } - llvm::FailureOr parseImportDesc(MLIRContext *ctx) { - auto importLoc = getLocation(); - auto importType = consumeByte(); - auto packager = [](auto parseResult) -> llvm::FailureOr { + FailureOr parseImportDesc(MLIRContext *ctx) { + FileLineColLoc importLoc = getLocation(); + FailureOr importType = consumeByte(); + auto packager = [](auto parseResult) -> FailureOr { if (llvm::failed(parseResult)) return failure(); return {*parseResult}; @@ -508,28 +513,26 @@ class ParserHead { case WasmBinaryEncoding::Import::globalType: return packager(parseGlobalType(ctx)); default: - return emitError(importLoc, "Invalid import type descriptor: ") + return emitError(importLoc, "invalid import type descriptor: ") << static_cast(*importType); } } parsed_inst_t parseExpression(OpBuilder &builder, WasmModuleSymbolTables const &symbols, - llvm::ArrayRef locals = {}) { + ArrayRef locals = {}) { auto eParser = ExpressionParser{*this, symbols, locals}; return eParser.parse(builder); } bool end() const { return curHead().empty(); } - ParserHead copy() const { - return *this; - } + ParserHead copy() const { return *this; } private: - llvm::StringRef curHead() const { return head.drop_front(offset); } + StringRef curHead() const { return head.drop_front(offset); } - llvm::FailureOr peek() const { + FailureOr peek() const { if (end()) return emitError( getLocation(), @@ -539,14 +542,14 @@ class ParserHead { size_t size() const { return head.size() - offset; } - llvm::StringRef head; + StringRef head; StringAttr locName; unsigned anchorOffset{0}; unsigned offset{0}; }; template <> -llvm::FailureOr ParserHead::parseLiteral() { +FailureOr ParserHead::parseLiteral() { auto bytes = consumeNBytes(4); if (failed(bytes)) return failure(); @@ -556,7 +559,7 @@ llvm::FailureOr ParserHead::parseLiteral() { } template <> -llvm::FailureOr ParserHead::parseLiteral() { +FailureOr ParserHead::parseLiteral() { auto bytes = consumeNBytes(8); if (failed(bytes)) return failure(); @@ -566,13 +569,13 @@ llvm::FailureOr ParserHead::parseLiteral() { } template <> -llvm::FailureOr ParserHead::parseLiteral() { +FailureOr ParserHead::parseLiteral() { char const *error = nullptr; uint32_t res{0}; unsigned encodingSize{0}; - auto src = curHead(); - auto decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize, - src.bytes_end(), &error); + StringRef src = curHead(); + uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); if (error) return emitError(getLocation(), error); @@ -585,13 +588,13 @@ llvm::FailureOr ParserHead::parseLiteral() { } template <> -llvm::FailureOr ParserHead::parseLiteral() { +FailureOr ParserHead::parseLiteral() { char const *error = nullptr; int32_t res{0}; unsigned encodingSize{0}; - auto src = curHead(); - auto decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, - src.bytes_end(), &error); + StringRef src = curHead(); + int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); if (error) return emitError(getLocation(), error); if (std::isgreater(decoded, std::numeric_limits::max()) || @@ -604,12 +607,12 @@ llvm::FailureOr ParserHead::parseLiteral() { } template <> -llvm::FailureOr ParserHead::parseLiteral() { +FailureOr ParserHead::parseLiteral() { char const *error = nullptr; unsigned encodingSize{0}; - auto src = curHead(); - auto res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, - src.bytes_end(), &error); + StringRef src = curHead(); + int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); if (error) return emitError(getLocation(), error); @@ -617,21 +620,21 @@ llvm::FailureOr ParserHead::parseLiteral() { return res; } -llvm::FailureOr ParserHead::parseVectorSize() { +FailureOr ParserHead::parseVectorSize() { return parseLiteral(); } -inline llvm::FailureOr ParserHead::parseUI32() { +inline FailureOr ParserHead::parseUI32() { return parseLiteral(); } -inline llvm::FailureOr ParserHead::parseI64() { +inline FailureOr ParserHead::parseI64() { return parseLiteral(); } template inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) { - return emitError(*currentOpLoc, "Unknown instruction opcode: ") + return emitError(*currentOpLoc, "unknown instruction opcode: ") << static_cast(opCode); } @@ -645,7 +648,7 @@ void ValueStack::dump() const { // end of the vector. Iterate in reverse so that the first thing we print // is the top of the stack. size_t stackSize = size(); - for (size_t idx = 0 ; idx < stackSize ;) { + for (size_t idx = 0; idx < stackSize;) { size_t actualIdx = stackSize - 1 - idx; llvm::dbgs() << " "; values[actualIdx].dump(); @@ -663,18 +666,17 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n"); if (operandTypes.size() > values.size()) return emitError(*opLoc, - "Stack doesn't contain enough values. Trying to get ") + "stack doesn't contain enough values. Trying to get ") << operandTypes.size() << " operands on a stack containing only " << values.size() << " values."; size_t stackIdxOffset = values.size() - operandTypes.size(); - llvm::SmallVector res{}; + SmallVector res{}; res.reserve(operandTypes.size()); for (size_t i{0}; i < operandTypes.size(); ++i) { Value operand = values[i + stackIdxOffset]; Type stackType = operand.getType(); if (stackType != operandTypes[i]) - return emitError(*opLoc, - "Invalid operand type on stack. Expecting ") + return emitError(*opLoc, "invalid operand type on stack. Expecting ") << operandTypes[i] << ", value on stack is of type " << stackType << "."; LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n"); @@ -690,9 +692,9 @@ LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size() << "\n"); LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n"); - for (auto val : results) { + for (Value val : results) { if (!isWasmValueType(val.getType())) - return emitError(*opLoc, "Invalid value type on stack: ") + return emitError(*opLoc, "invalid value type on stack: ") << val.getType(); LLVM_DEBUG(llvm::dbgs() << " PUSH: " << val << "\n"); values.push_back(val); @@ -702,8 +704,9 @@ LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { return success(); } -template -parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte endByte) { +template +parsed_inst_t ExpressionParser::parse(OpBuilder &builder, + UniqueByte endByte) { auto res = parse(builder, ByteSequence{}); if (failed(res)) return failure(); @@ -711,13 +714,13 @@ parsed_inst_t ExpressionParser::parse(OpBuilder &builder, UniqueByte -llvm::FailureOr +FailureOr ExpressionParser::parse(OpBuilder &builder, ByteSequence parsingEndFilters) { - llvm::SmallVector res; + SmallVector res; for (;;) { currentOpLoc = parser.getLocation(); - auto opCode = parser.consumeByte(); + FailureOr opCode = parser.consumeByte(); if (failed(opCode)) return failure(); if (isValueOneOf(*opCode, parsingEndFilters)) @@ -732,7 +735,6 @@ ExpressionParser::parse(OpBuilder &builder, } } - template inline Type buildLiteralType(OpBuilder &); @@ -766,7 +768,8 @@ inline Type buildLiteralType(OpBuilder &builder) { return builder.getF64Type(); } -template>> +template >> struct AttrHolder; template @@ -779,7 +782,7 @@ struct AttrHolder>> { using type = FloatAttr; }; -template +template using attr_holder_t = typename AttrHolder::type; template (builder); } - class WasmBinaryParser { private: struct SectionRegistry { - using section_location_t = llvm::StringRef; + using section_location_t = StringRef; - std::array, highestWasmSectionID+1> registry; + std::array, highestWasmSectionID + 1> + registry; template std::conditional_t, - llvm::ArrayRef> + ArrayRef> getContentForSection() const { constexpr auto idx = static_cast(SecType); if constexpr (sectionShouldBeUnique(SecType)) { @@ -858,7 +861,7 @@ class WasmBinaryParser { section_location_t location, Location loc) { if (sectionShouldBeUnique(secType) && hasSection(secType)) return emitError(loc, - "Trying to add a second instance of unique section"); + "trying to add a second instance of unique section"); registry[static_cast(secType)].push_back(location); emitRemark(loc, "Adding section with section ID ") @@ -868,26 +871,25 @@ class WasmBinaryParser { LogicalResult populateFromBody(ParserHead ph) { while (!ph.end()) { - auto sectionLoc = ph.getLocation(); - auto secType = ph.parseWasmSectionType(); + FileLineColLoc sectionLoc = ph.getLocation(); + FailureOr secType = ph.parseWasmSectionType(); if (failed(secType)) return failure(); - auto secSizeParsed = ph.parseLiteral(); + FailureOr secSizeParsed = ph.parseLiteral(); if (failed(secSizeParsed)) return failure(); - auto secSize = *secSizeParsed; - auto sectionContent = ph.consumeNBytes(secSize); + uint32_t secSize = *secSizeParsed; + FailureOr sectionContent = ph.consumeNBytes(secSize); if (failed(sectionContent)) return failure(); - auto registration = + LogicalResult registration = registerSection(*secType, *sectionContent, sectionLoc); if (failed(registration)) return failure(); - } return success(); } @@ -917,10 +919,10 @@ class WasmBinaryParser { auto secSrc = secContent.value(); ParserHead ph{secSrc, sectionNameAttr}; - auto nElemsParsed = ph.parseVectorSize(); + FailureOr nElemsParsed = ph.parseVectorSize(); if (failed(nElemsParsed)) return failure(); - auto nElems = *nElemsParsed; + uint32_t nElems = *nElemsParsed; LLVM_DEBUG(llvm::dbgs() << "Starting to parse " << nElems << " items for section " << secName << ".\n"); for (size_t i = 0; i < nElems; ++i) { @@ -929,31 +931,31 @@ class WasmBinaryParser { } if (!ph.end()) - return emitError(getLocation(), "Unparsed garbage at end of section ") + return emitError(getLocation(), "unparsed garbage at end of section ") << secName; return success(); } /// Handles the registration of a function import - LogicalResult visitImport(Location loc, llvm::StringRef moduleName, - llvm::StringRef importName, TypeIdxRecord tid) { + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TypeIdxRecord tid) { using llvm::Twine; if (tid.id >= symbols.moduleFuncTypes.size()) - return emitError(loc, "Invalid type id: ") + return emitError(loc, "invalid type id: ") << tid.id << ". Only " << symbols.moduleFuncTypes.size() << " type registration."; - auto type = symbols.moduleFuncTypes[tid.id]; - auto symbol = symbols.getNewFuncSymbolName(); - auto funcOp = builder.create( - loc, symbol, moduleName, importName, type); + FunctionType type = symbols.moduleFuncTypes[tid.id]; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = + builder.create(loc, symbol, moduleName, importName, type); symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type}); return funcOp.verify(); } /// Handles the registration of a memory import - LogicalResult visitImport(Location loc, llvm::StringRef moduleName, - llvm::StringRef importName, LimitType limitType) { - auto symbol = symbols.getNewMemorySymbolName(); + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, LimitType limitType) { + std::string symbol = symbols.getNewMemorySymbolName(); auto memOp = builder.create(loc, symbol, moduleName, importName, limitType); symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)}); @@ -961,9 +963,9 @@ class WasmBinaryParser { } /// Handles the registration of a table import - LogicalResult visitImport(Location loc, llvm::StringRef moduleName, - llvm::StringRef importName, TableType tableType) { - auto symbol = symbols.getNewTableSymbolName(); + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TableType tableType) { + std::string symbol = symbols.getNewTableSymbolName(); auto tableOp = builder.create(loc, symbol, moduleName, importName, tableType); symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)}); @@ -971,14 +973,14 @@ class WasmBinaryParser { } /// Handles the registration of a global variable import - LogicalResult visitImport(Location loc, llvm::StringRef moduleName, - llvm::StringRef importName, - GlobalTypeRecord globalType) { - auto symbol = symbols.getNewGlobalSymbolName(); + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, GlobalTypeRecord globalType) { + std::string symbol = symbols.getNewGlobalSymbolName(); auto giOp = builder.create(loc, symbol, moduleName, importName, globalType.type, globalType.isMutable); - symbols.globalSymbols.push_back({{FlatSymbolRefAttr::get(giOp)}, giOp.getType()}); + symbols.globalSymbols.push_back( + {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()}); return giOp.verify(); } @@ -987,63 +989,62 @@ class WasmBinaryParser { : builder{ctx}, ctx{ctx} { ctx->loadAllAvailableDialects(); if (sourceMgr.getNumBuffers() != 1) { - emitError(UnknownLoc::get(ctx), "One source file should be provided"); + emitError(UnknownLoc::get(ctx), "one source file should be provided"); return; } - auto sourceBufId = sourceMgr.getMainFileID(); - auto source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer(); + uint32_t sourceBufId = sourceMgr.getMainFileID(); + StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer(); srcName = StringAttr::get( - ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier()); + ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier()); auto parser = ParserHead{source, srcName}; auto const wasmHeader = StringRef{"\0asm", 4}; - auto magicLoc = parser.getLocation(); - auto magic = parser.consumeNBytes(wasmHeader.size()); + FileLineColLoc magicLoc = parser.getLocation(); + FailureOr magic = parser.consumeNBytes(wasmHeader.size()); if (failed(magic) || magic->compare(wasmHeader)) { - emitError(magicLoc, - "Source file does not contain valid Wasm header."); + emitError(magicLoc, "source file does not contain valid Wasm header."); return; } auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; - auto versionLoc = parser.getLocation(); - auto version = parser.consumeNBytes(expectedVersionString.size()); + FileLineColLoc versionLoc = parser.getLocation(); + FailureOr version = + parser.consumeNBytes(expectedVersionString.size()); if (failed(version)) return; if (version->compare(expectedVersionString)) { emitError(versionLoc, - "Unsupported Wasm version. Only version 1 is supported."); + "unsupported Wasm version. Only version 1 is supported."); return; } - auto fillRegistry = registry.populateFromBody(parser.copy()); + LogicalResult fillRegistry = registry.populateFromBody(parser.copy()); if (failed(fillRegistry)) return; mOp = builder.create(getLocation()); - builder.setInsertionPointToStart( - &mOp.getBodyRegion().front()); - auto parsingTypes = parseSection(); + builder.setInsertionPointToStart(&mOp.getBodyRegion().front()); + LogicalResult parsingTypes = parseSection(); if (failed(parsingTypes)) return; - auto parsingImports = parseSection(); + LogicalResult parsingImports = parseSection(); if (failed(parsingImports)) return; firstInternalFuncID = symbols.funcSymbols.size(); - auto parsingFunctions = parseSection(); + LogicalResult parsingFunctions = parseSection(); if (failed(parsingFunctions)) return; - auto parsingTables = parseSection(); + LogicalResult parsingTables = parseSection(); if (failed(parsingTables)) return; - auto parsingMems = parseSection(); + LogicalResult parsingMems = parseSection(); if (failed(parsingMems)) return; - auto parsingExports = parseSection(); + LogicalResult parsingExports = parseSection(); if (failed(parsingExports)) return; @@ -1068,8 +1069,9 @@ class WasmBinaryParser { template <> LogicalResult -WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { - auto importLoc = ph.getLocation(); +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc importLoc = ph.getLocation(); auto moduleName = ph.parseName(); if (failed(moduleName)) return failure(); @@ -1078,7 +1080,7 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, size if (failed(importName)) return failure(); - auto import = ph.parseImportDesc(ctx); + FailureOr import = ph.parseImportDesc(ctx); if (failed(import)) return failure(); @@ -1093,24 +1095,23 @@ template <> LogicalResult WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { - auto exportLoc = ph.getLocation(); + FileLineColLoc exportLoc = ph.getLocation(); auto exportName = ph.parseName(); if (failed(exportName)) return failure(); - auto opcode = ph.consumeByte(); + FailureOr opcode = ph.consumeByte(); if (failed(opcode)) return failure(); - auto idx = ph.parseLiteral(); + FailureOr idx = ph.parseLiteral(); if (failed(idx)) return failure(); - using SymbolRefDesc = - std::variant, - llvm::SmallVector, - llvm::SmallVector>; + using SymbolRefDesc = std::variant, + SmallVector, + SmallVector>; SymbolRefDesc currentSymbolList; std::string symbolType = ""; @@ -1132,7 +1133,7 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, currentSymbolList = symbols.globalSymbols; break; default: - return emitError(exportLoc, "Invalid value for export type: ") + return emitError(exportLoc, "invalid value for export type: ") << std::to_integer(*opcode); } @@ -1142,7 +1143,7 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, emitError( exportLoc, llvm::formatv( - "Trying to export {0} {1} which is undefined in this scope", + "trying to export {0} {1} which is undefined in this scope", symbolType, *idx)); return failure(); } @@ -1155,21 +1156,23 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); - auto symName = SymbolTable::getSymbolName(op); + StringAttr symName = SymbolTable::getSymbolName(op); return SymbolTable{mOp}.rename(symName, *exportName); } template <> LogicalResult -WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { - auto opLocation = ph.getLocation(); - auto tableType = ph.parseTableType(ctx); +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr tableType = ph.parseTableType(ctx); if (failed(tableType)) return failure(); LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *tableType << '\n'); - auto symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); - auto tableOp = builder.create(opLocation, symbol.strref(), *tableType); + StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); + auto tableOp = + builder.create(opLocation, symbol.strref(), *tableType); symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)}); return success(); } @@ -1178,17 +1181,17 @@ template <> LogicalResult WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { - auto opLoc = ph.getLocation(); + FileLineColLoc opLoc = ph.getLocation(); auto typeIdxParsed = ph.parseLiteral(); if (failed(typeIdxParsed)) return failure(); - auto typeIdx = *typeIdxParsed; + uint32_t typeIdx = *typeIdxParsed; if (typeIdx >= symbols.moduleFuncTypes.size()) - return emitError(getLocation(), "Invalid type index: ") << typeIdx; - auto symbol = symbols.getNewFuncSymbolName(); + return emitError(getLocation(), "invalid type index: ") << typeIdx; + std::string symbol = symbols.getNewFuncSymbolName(); auto funcOp = builder.create(opLoc, symbol, symbols.moduleFuncTypes[typeIdx]); - auto *block = funcOp.addEntryBlock(); + Block *block = funcOp.addEntryBlock(); auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToEnd(block); builder.create(opLoc); @@ -1203,7 +1206,7 @@ template <> LogicalResult WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { - auto funcType = ph.parseFunctionType(ctx); + FailureOr funcType = ph.parseFunctionType(ctx); if (failed(funcType)) return failure(); LLVM_DEBUG(llvm::dbgs() << "Parsed function type " << *funcType << '\n'); @@ -1213,30 +1216,29 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, template <> LogicalResult -WasmBinaryParser::parseSectionItem(ParserHead &ph, size_t) { - auto opLocation = ph.getLocation(); - auto memory = ph.parseLimit(ctx); +WasmBinaryParser::parseSectionItem(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr memory = ph.parseLimit(ctx); if (failed(memory)) return failure(); LLVM_DEBUG(llvm::dbgs() << " Registering memory " << *memory << '\n'); - auto symbol = symbols.getNewMemorySymbolName(); + std::string symbol = symbols.getNewMemorySymbolName(); auto memOp = builder.create(opLocation, symbol, *memory); symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)}); return success(); } } // namespace -namespace mlir { -namespace wasm { +namespace mlir::wasm { OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context) { WasmBinaryParser wBN{source, context}; - auto mOp = wBN.getModule(); + ModuleOp mOp = wBN.getModule(); if (mOp) return {mOp}; return {nullptr}; } -} // namespace wasm -} // namespace mlir +} // namespace mlir::wasm diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp index 9c0f7702a96aa..03b97846d45d3 100644 --- a/mlir/lib/Target/Wasm/TranslateRegistration.cpp +++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp @@ -11,18 +11,18 @@ #include "mlir/Target/Wasm/WasmImporter.h" #include "mlir/Tools/mlir-translate/Translation.h" - using namespace mlir; namespace mlir { void registerFromWasmTranslation() { TranslateToMLIRRegistration registration{ - "import-wasm", "Translate WASM to MLIR", - [](llvm::SourceMgr &sourceMgr, MLIRContext* context) -> OwningOpRef { - return wasm::importWebAssemblyToModule(sourceMgr, context); - }, [](DialectRegistry& registry) { - registry.insert(); - } - }; + "import-wasm", "Translate WASM to MLIR", + [](llvm::SourceMgr &sourceMgr, + MLIRContext *context) -> OwningOpRef { + return wasm::importWebAssemblyToModule(sourceMgr, context); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }}; } } // namespace mlir diff --git a/mlir/test/Target/Wasm/bad_wasm_version.yaml b/mlir/test/Target/Wasm/bad_wasm_version.yaml index 4fed1d5a3af3c..f834afbef679d 100644 --- a/mlir/test/Target/Wasm/bad_wasm_version.yaml +++ b/mlir/test/Target/Wasm/bad_wasm_version.yaml @@ -1,6 +1,6 @@ # RUN: yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s -# CHECK: Unsupported Wasm version +# CHECK: unsupported Wasm version --- !WASM FileHeader: diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml index ffb26f563141a..5adbd861bad36 100644 --- a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml +++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml @@ -2,7 +2,7 @@ # FIXME: The error code here should be nonzero. -# CHECK: Trying to export function 42 which is undefined in this scope +# CHECK: trying to export function 42 which is undefined in this scope --- !WASM FileHeader: diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml index 961e9cc6e8029..b01a623c41209 100644 --- a/mlir/test/Target/Wasm/invalid_function_type_index.yaml +++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml @@ -1,5 +1,5 @@ # RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s -# CHECK: error: Invalid type index: 2 +# CHECK: error: invalid type index: 2 # FIXME: mlir-translate should not return 0 here. diff --git a/mlir/test/Target/Wasm/missing_header.yaml b/mlir/test/Target/Wasm/missing_header.yaml index 5610f9c5c6e33..a9f812e0a77f8 100644 --- a/mlir/test/Target/Wasm/missing_header.yaml +++ b/mlir/test/Target/Wasm/missing_header.yaml @@ -1,6 +1,6 @@ # RUN: not yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s -# CHECK: Source file does not contain valid Wasm header +# CHECK: source file does not contain valid Wasm header --- !WASM Sections: From 6f6c52785bedfaed8ef46426c9cb1fba7e566f9f Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Thu, 7 Aug 2025 09:25:34 +0000 Subject: [PATCH 07/14] [MLIR][WASM] Make WasmSSA Importer signal program failure on error Also contains non functional changes to use the `LDBG` macro. --- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 61 +++++++++++-------- .../Wasm/function_export_out_of_scope.yaml | 4 +- .../Wasm/invalid_function_type_index.yaml | 4 +- 3 files changed, 38 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index dd8b86670c31a..ed3b43a4408ac 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -15,6 +15,7 @@ #include "mlir/Target/Wasm/WasmImporter.h" #include "llvm/ADT/Statistic.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/LEB128.h" @@ -322,17 +323,16 @@ class ParserHead { } FailureOr consumeNBytes(size_t nBytes) { - LLVM_DEBUG(llvm::dbgs() << "Consume " << nBytes << " bytes\n"); - LLVM_DEBUG(llvm::dbgs() << " Bytes remaining: " << size() << "\n"); - LLVM_DEBUG(llvm::dbgs() << " Current offset: " << offset << "\n"); + LDBG() << "Consume " << nBytes << " bytes"; + LDBG() << " Bytes remaining: " << size(); + LDBG() << " Current offset: " << offset; if (nBytes > size()) return emitError(getLocation(), "trying to extract ") << nBytes << "bytes when only " << size() << "are avilables"; StringRef res = head.slice(offset, offset + nBytes); offset += nBytes; - LLVM_DEBUG(llvm::dbgs() - << " Updated offset (+" << nBytes << "): " << offset << "\n"); + LDBG() << " Updated offset (+" << nBytes << "): " << offset; return res; } @@ -660,10 +660,9 @@ void ValueStack::dump() const { #endif parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { - LLVM_DEBUG(llvm::dbgs() << "Popping from ValueStack\n"); - LLVM_DEBUG(llvm::dbgs() << " Elements(s) to pop: " << operandTypes.size() - << "\n"); - LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n"); + LDBG() << "Popping from ValueStack\n" + << " Elements(s) to pop: " << operandTypes.size() << "\n" + << " Current stack size: " << values.size(); if (operandTypes.size() > values.size()) return emitError(*opLoc, "stack doesn't contain enough values. Trying to get ") @@ -679,28 +678,27 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { return emitError(*opLoc, "invalid operand type on stack. Expecting ") << operandTypes[i] << ", value on stack is of type " << stackType << "."; - LLVM_DEBUG(llvm::dbgs() << " POP: " << operand << "\n"); + LDBG() << " POP: " << operand; res.push_back(operand); } values.resize(values.size() - operandTypes.size()); - LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n"); + LDBG() << " Updated stack size: " << values.size(); return res; } LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { - LLVM_DEBUG(llvm::dbgs() << "Pushing to ValueStack\n"); - LLVM_DEBUG(llvm::dbgs() << " Elements(s) to push: " << results.size() - << "\n"); - LLVM_DEBUG(llvm::dbgs() << " Current stack size: " << values.size() << "\n"); + LDBG() << "Pushing to ValueStack\n" + << " Elements(s) to push: " << results.size() << "\n" + << " Current stack size: " << values.size(); for (Value val : results) { if (!isWasmValueType(val.getType())) return emitError(*opLoc, "invalid value type on stack: ") << val.getType(); - LLVM_DEBUG(llvm::dbgs() << " PUSH: " << val << "\n"); + LDBG() << " PUSH: " << val; values.push_back(val); } - LLVM_DEBUG(llvm::dbgs() << " Updated stack size: " << values.size() << "\n"); + LDBG() << " Updated stack size: " << values.size(); return success(); } @@ -913,7 +911,7 @@ class WasmBinaryParser { }; auto secContent = registry.getContentForSection
(); if (!secContent) { - LLVM_DEBUG(llvm::dbgs() << secName << " section is not present in file."); + LDBG() << secName << " section is not present in file."; return success(); } @@ -923,8 +921,8 @@ class WasmBinaryParser { if (failed(nElemsParsed)) return failure(); uint32_t nElems = *nElemsParsed; - LLVM_DEBUG(llvm::dbgs() << "Starting to parse " << nElems - << " items for section " << secName << ".\n"); + LDBG() << "Starting to parse " << nElems << " items for section " + << secName; for (size_t i = 0; i < nElems; ++i) { if (failed(parseSectionItem
(ph, i))) return failure(); @@ -984,9 +982,18 @@ class WasmBinaryParser { return giOp.verify(); } + // Detect occurence of errors + LogicalResult peekDiag(Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) + isValid = false; + return failure(); + } + public: WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx) : builder{ctx}, ctx{ctx} { + ctx->getDiagEngine().registerHandler( + [this](Diagnostic &diag) { return peekDiag(diag); }); ctx->loadAllAvailableDialects(); if (sourceMgr.getNumBuffers() != 1) { emitError(UnknownLoc::get(ctx), "one source file should be provided"); @@ -1055,7 +1062,11 @@ class WasmBinaryParser { numTableSectionItems = symbols.tableSymbols.size(); } - ModuleOp getModule() { return mOp; } + ModuleOp getModule() { + if (isValid) + return mOp; + return ModuleOp{}; + } private: mlir::StringAttr srcName; @@ -1065,6 +1076,7 @@ class WasmBinaryParser { ModuleOp mOp; SectionRegistry registry; size_t firstInternalFuncID{0}; + bool isValid{true}; }; template <> @@ -1168,8 +1180,7 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, FailureOr tableType = ph.parseTableType(ctx); if (failed(tableType)) return failure(); - LLVM_DEBUG(llvm::dbgs() << " Parsed table description: " << *tableType - << '\n'); + LDBG() << " Parsed table description: " << *tableType; StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); auto tableOp = builder.create(opLocation, symbol.strref(), *tableType); @@ -1209,7 +1220,7 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, FailureOr funcType = ph.parseFunctionType(ctx); if (failed(funcType)) return failure(); - LLVM_DEBUG(llvm::dbgs() << "Parsed function type " << *funcType << '\n'); + LDBG() << "Parsed function type " << *funcType; symbols.moduleFuncTypes.push_back(*funcType); return success(); } @@ -1223,7 +1234,7 @@ WasmBinaryParser::parseSectionItem(ParserHead &ph, if (failed(memory)) return failure(); - LLVM_DEBUG(llvm::dbgs() << " Registering memory " << *memory << '\n'); + LDBG() << " Registering memory " << *memory; std::string symbol = symbols.getNewMemorySymbolName(); auto memOp = builder.create(opLocation, symbol, *memory); symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)}); diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml index 5adbd861bad36..b08c2c87abdb3 100644 --- a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml +++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml @@ -1,6 +1,4 @@ -# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s - -# FIXME: The error code here should be nonzero. +# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s # CHECK: trying to export function 42 which is undefined in this scope diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml index b01a623c41209..2d2954aa32dda 100644 --- a/mlir/test/Target/Wasm/invalid_function_type_index.yaml +++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml @@ -1,8 +1,6 @@ -# RUN: yaml2obj %s | mlir-translate --import-wasm -o - 2>&1 | FileCheck %s +# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s # CHECK: error: invalid type index: 2 -# FIXME: mlir-translate should not return 0 here. - --- !WASM FileHeader: Version: 0x00000001 From b58ddbe8f8292ff26445fdb2a683b912c7719828 Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Fri, 8 Aug 2025 02:20:42 +0000 Subject: [PATCH 08/14] [MLIR] Adding yaml2obj as dependency to mlir unit tests This is due to the tests in the wasm importer requiring it. --- mlir/test/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 89568e7766ae5..c21e3610b9066 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -123,7 +123,7 @@ set(MLIR_TEST_DEPENDS tblgen-to-irdl ) if(NOT MLIR_STANDALONE_BUILD) - list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file) + list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file yaml2obj) endif() set(MLIR_TEST_DEPENDS ${MLIR_TEST_DEPENDS} From ee8d3889ba59bd94ea202f3b6e4962df72ead4bd Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Tue, 12 Aug 2025 10:24:20 +0900 Subject: [PATCH 09/14] [WASM][MLIR] Importer tests uses new custom format --- mlir/test/Target/Wasm/memory_min_eq_max.mlir | 2 +- mlir/test/Target/Wasm/memory_min_max.mlir | 2 +- mlir/test/Target/Wasm/memory_min_no_max.mlir | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir index 088e28685d09a..2ba5ab50d51fa 100644 --- a/mlir/test/Target/Wasm/memory_min_eq_max.mlir +++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir @@ -4,4 +4,4 @@ (module (memory 0 0)) */ -// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa, sym_name = "mem_0", sym_visibility = "nested"}> : () -> () +// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir index 16d3468279d42..ebf64189189f8 100644 --- a/mlir/test/Target/Wasm/memory_min_max.mlir +++ b/mlir/test/Target/Wasm/memory_min_max.mlir @@ -4,4 +4,4 @@ (module (memory 0 65536)) */ -// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa, sym_name = "mem_0", sym_visibility = "nested"}> : () -> () +// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir index f71cb1098be18..8d8878618bcc0 100644 --- a/mlir/test/Target/Wasm/memory_min_no_max.mlir +++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir @@ -4,4 +4,4 @@ (module (memory 1)) */ -// CHECK-LABEL: "wasmssa.memory"() <{limits = !wasmssa, sym_name = "mem_0", sym_visibility = "nested"}> : () -> () +// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa From 6e952fc74b578e07b334ed4d4e995cfaed4e7186 Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Tue, 12 Aug 2025 13:48:20 +0900 Subject: [PATCH 10/14] [MLIR][WASM] Implement review remarks --- .../include/mlir/Target/Wasm/WasmBinaryEncoding.h | 15 ++++++++------- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 8 +++++++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h index a5b124eecbe67..3280432b5f038 100644 --- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -4,19 +4,20 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Define encodings for WebAssembly instructions, types, etc from the +// Define various flags used to encode instructions, types, etc. in // WebAssembly binary format. // -// Each encoding is defined in the WebAssembly binary specification. +// These encodings are defined in the WebAssembly binary format specification. // //===----------------------------------------------------------------------===// #ifndef MLIR_TARGET_WASMBINARYENCODING #define MLIR_TARGET_WASMBINARYENCODING #include + namespace mlir { struct WasmBinaryEncoding { - /// Byte encodings for WASM instructions. + /// Byte encodings for Wasm instructions. struct OpCode { // Locals, globals, constants. static constexpr std::byte constI32{0x41}; @@ -25,7 +26,7 @@ struct WasmBinaryEncoding { static constexpr std::byte constFP64{0x44}; }; - /// Byte encodings of types in WASM binaries + /// Byte encodings of types in Wasm binaries struct Type { static constexpr std::byte emptyBlockType{0x40}; static constexpr std::byte funcType{0x60}; @@ -38,7 +39,7 @@ struct WasmBinaryEncoding { static constexpr std::byte i32{0x7F}; }; - /// Byte encodings of WASM imports. + /// Byte encodings of Wasm imports. struct Import { static constexpr std::byte typeID{0x00}; static constexpr std::byte tableType{0x01}; @@ -46,7 +47,7 @@ struct WasmBinaryEncoding { static constexpr std::byte globalType{0x03}; }; - /// Byte encodings for WASM limits. + /// Byte encodings for Wasm limits. struct LimitHeader { static constexpr std::byte lowLimitOnly{0x00}; static constexpr std::byte bothLimits{0x01}; @@ -58,7 +59,7 @@ struct WasmBinaryEncoding { static constexpr std::byte isMutable{0x01}; }; - /// Byte encodings describing WASM exports. + /// Byte encodings describing Wasm exports. struct Export { static constexpr std::byte function{0x00}; static constexpr std::byte table{0x01}; diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index ed3b43a4408ac..a958a71e19018 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -5,6 +5,11 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// This file implements the WebAssembly importer. +// +//===----------------------------------------------------------------------===// + #include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -19,6 +24,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/LEB128.h" +#include #include #include @@ -648,7 +654,7 @@ void ValueStack::dump() const { // end of the vector. Iterate in reverse so that the first thing we print // is the top of the stack. size_t stackSize = size(); - for (size_t idx = 0; idx < stackSize;) { + for (size_t idx = 0; idx < stackSize; idx++) { size_t actualIdx = stackSize - 1 - idx; llvm::dbgs() << " "; values[actualIdx].dump(); From 5ebf0e1731b7bda60494d277b9ebdffb6f8d1d1b Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Tue, 12 Aug 2025 17:15:48 +0900 Subject: [PATCH 11/14] [MLIR][WASM] Fix template to please MSVC --- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index a958a71e19018..4dc1d425188ad 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -106,7 +106,7 @@ template struct UniqueByte : ByteSequence {}; template -constexpr ByteSequence +constexpr ByteSequence byteSeqFromIntSeq(std::integer_sequence) { return {}; } From 819e2971a9701cb79b5eb3ce4d7f95689406e729 Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Wed, 13 Aug 2025 13:41:10 +0900 Subject: [PATCH 12/14] [MLIR][WASM] NFC remove deadcode in WasmImporter --- mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 4dc1d425188ad..d0fa70c26faee 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -1,4 +1,4 @@ -//===- TranslateFromWasm.cpp - Translating to C++ calls -------------------===// +//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -95,25 +95,10 @@ constexpr bool sectionShouldBeUnique(WasmSectionType secType) { template struct ByteSequence {}; -template -constexpr ByteSequence -operator+(ByteSequence, ByteSequence) { - return {}; -} - /// Template class for representing a byte sequence of only one byte template struct UniqueByte : ByteSequence {}; -template -constexpr ByteSequence -byteSeqFromIntSeq(std::integer_sequence) { - return {}; -} - -constexpr auto allOpCodes = - byteSeqFromIntSeq(std::make_integer_sequence()); - constexpr ByteSequence< WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64, WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64, From bff79f3af339387d8a8fa71869cf9b00aa859ecb Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Thu, 14 Aug 2025 10:06:14 +0900 Subject: [PATCH 13/14] [MLIR][WASM] NFC: fix typo and documentation --- mlir/include/mlir/Target/Wasm/WasmImporter.h | 9 +++------ mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h index 5cc42a1f32fa4..9b52f13e042df 100644 --- a/mlir/include/mlir/Target/Wasm/WasmImporter.h +++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h @@ -21,12 +21,9 @@ namespace mlir::wasm { -/// Translates the given operation to C++ code. The operation or operations in -/// the region of 'op' need almost all be in EmitC dialect. The parameter -/// 'declareVariablesAtTop' enforces that all variables for op results and block -/// arguments are declared at the beginning of the function. -/// If parameter 'fileId' is non-empty, then body of `emitc.file` ops -/// with matching id are emitted. +/// If `source` contains a valid Wasm binary file, this function returns a +/// a ModuleOp containing the representation of trhe Wasm module encoded in +/// the source file in the `wasmssa` dialect. OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context); } // namespace mlir::wasm diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index d0fa70c26faee..c23a2915ef18e 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -319,7 +319,7 @@ class ParserHead { LDBG() << " Current offset: " << offset; if (nBytes > size()) return emitError(getLocation(), "trying to extract ") - << nBytes << "bytes when only " << size() << "are avilables"; + << nBytes << "bytes when only " << size() << "are available"; StringRef res = head.slice(offset, offset + nBytes); offset += nBytes; From 3e2caec7eea3464921ef391953697b883eb31ddd Mon Sep 17 00:00:00 2001 From: Luc Forget Date: Thu, 14 Aug 2025 14:25:15 +0900 Subject: [PATCH 14/14] [MLIR][WASM] NFC: fix typo --- mlir/include/mlir/Target/Wasm/WasmImporter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h index 9b52f13e042df..3f05bbecefc8a 100644 --- a/mlir/include/mlir/Target/Wasm/WasmImporter.h +++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h @@ -22,7 +22,7 @@ namespace mlir::wasm { /// If `source` contains a valid Wasm binary file, this function returns a -/// a ModuleOp containing the representation of trhe Wasm module encoded in +/// a ModuleOp containing the representation of the Wasm module encoded in /// the source file in the `wasmssa` dialect. OwningOpRef importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context);