diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index 2b4697434717d..cc5aaed416512 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -613,6 +613,9 @@ class RawEmitterOstream : public raw_ostream { } // namespace void EncodingEmitter::writeTo(raw_ostream &os) const { + // Reserve space in the ostream for the encoded contents. + os.reserveExtraSpace(size()); + for (auto &prevResult : prevResultList) os.write((const char *)prevResult.data(), prevResult.size()); os.write((const char *)currentResult.data(), currentResult.size()); diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp index cb915a092a0be..c036fe26b1b36 100644 --- a/mlir/unittests/Bytecode/BytecodeTest.cpp +++ b/mlir/unittests/Bytecode/BytecodeTest.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" #include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -37,6 +38,29 @@ module @TestDialectResources attributes { #-} )"; +struct MockOstream final : public raw_ostream { + std::unique_ptr buffer; + size_t size = 0; + + MOCK_METHOD(void, reserveExtraSpace, (uint64_t extraSpace), (override)); + + MockOstream() : raw_ostream(true) {} + uint64_t current_pos() const override { return pos; } + +private: + size_t pos = 0; + + void write_impl(const char *ptr, size_t length) override { + if (pos + length <= size) { + memcpy((void *)(buffer.get() + pos), ptr, length); + pos += length; + } else { + report_fatal_error( + "Attempted to write past the end of the fixed size buffer."); + } + } +}; + TEST(Bytecode, MultiModuleWithResource) { MLIRContext context; Builder builder(&context); @@ -45,12 +69,17 @@ TEST(Bytecode, MultiModuleWithResource) { parseSourceString(irWithResources, parseConfig); ASSERT_TRUE(module); - // Write the module to bytecode - std::string buffer; - llvm::raw_string_ostream ostream(buffer); + // Write the module to bytecode. + MockOstream ostream; + EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) { + ostream.buffer = std::make_unique(space); + ostream.size = space; + }); ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream))); // Create copy of buffer which is aligned to requested resource alignment. + std::string buffer((char *)ostream.buffer.get(), + (char *)ostream.buffer.get() + ostream.size); constexpr size_t kAlignment = 0x20; size_t bufferSize = buffer.size(); buffer.reserve(bufferSize + kAlignment - 1);