Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions llvm/include/llvm/Support/raw_ostream.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef LLVM_SUPPORT_RAW_OSTREAM_H
#define LLVM_SUPPORT_RAW_OSTREAM_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DataTypes.h"
Expand Down Expand Up @@ -769,6 +770,18 @@ class buffer_unique_ostream : public raw_svector_ostream {
~buffer_unique_ostream() override { *OS << str(); }
};

// Creates an output stream with a fixed size buffer.
class fixed_buffer_ostream : public raw_ostream {
MutableArrayRef<std::byte> Buffer;
size_t Pos = 0;

void write_impl(const char *Ptr, size_t Size) final;
uint64_t current_pos() const final { return Pos; }

public:
fixed_buffer_ostream(MutableArrayRef<std::byte> Buffer);
};

// Helper struct to add indentation to raw_ostream. Instead of
// OS.indent(6) << "more stuff";
// you can use
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Support/raw_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,19 @@ void buffer_ostream::anchor() {}

void buffer_unique_ostream::anchor() {}

void fixed_buffer_ostream::write_impl(const char *Ptr, size_t Size) {
if (Pos + Size <= Buffer.size()) {
memcpy((void *)(Buffer.data() + Pos), Ptr, Size);
Pos += Size;
} else {
report_fatal_error(
"Attempted to write past the end of the fixed size buffer.");
}
}

fixed_buffer_ostream::fixed_buffer_ostream(MutableArrayRef<std::byte> Buffer)
: raw_ostream(true), Buffer{Buffer} {}

Error llvm::writeToOutput(StringRef OutputFileName,
std::function<Error(raw_ostream &)> Write) {
if (OutputFileName == "-")
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Bytecode/BytecodeWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ class BytecodeWriterConfig {
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config = {});

/// Writes the bytecode for the given operation to a memory-mapped buffer.
/// It only ever fails if setDesiredByteCodeVersion can't be honored.
/// Returns nullptr on failure.
std::shared_ptr<ArrayRef<std::byte>>
writeBytecode(Operation *op, const BytecodeWriterConfig &config = {});

} // namespace mlir

#endif // MLIR_BYTECODE_BYTECODEWRITER_H
71 changes: 63 additions & 8 deletions mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Memory.h"
#include "llvm/Support/raw_ostream.h"
#include <cstddef>
#include <optional>
#include <system_error>

#define DEBUG_TYPE "mlir-bytecode-writer"

Expand Down Expand Up @@ -652,7 +655,7 @@ class BytecodeWriter {
propertiesSection(numberingState, stringSection, config.getImpl()) {}

/// Write the bytecode for the given root operation.
LogicalResult write(Operation *rootOp, raw_ostream &os);
LogicalResult writeInto(Operation *rootOp, EncodingEmitter &emitter);

private:
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -718,9 +721,8 @@ class BytecodeWriter {
};
} // namespace

LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
EncodingEmitter emitter;

LogicalResult BytecodeWriter::writeInto(Operation *rootOp,
EncodingEmitter &emitter) {
// Emit the bytecode file header. This is how we identify the output as a
// bytecode file.
emitter.emitString("ML\xefR", "bytecode header");
Expand Down Expand Up @@ -761,9 +763,6 @@ LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
return rootOp->emitError(
"unexpected properties emitted incompatible with bytecode <5");

// Write the generated bytecode to the provided output stream.
emitter.writeTo(os);

return success();
}

Expand Down Expand Up @@ -1348,5 +1347,61 @@ void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) {
LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
BytecodeWriter writer(op, config);
return writer.write(op, os);
EncodingEmitter emitter;

if (succeeded(writer.writeInto(op, emitter))) {
emitter.writeTo(os);
return success();
}

return failure();
}

namespace {
struct MemoryMappedBlock {
static std::shared_ptr<MemoryMappedBlock>
createMemoryMappedBlock(size_t numBytes) {
auto instance = std::make_shared<MemoryMappedBlock>();

std::error_code ec;
instance->mmapBlock =
llvm::sys::OwningMemoryBlock{llvm::sys::Memory::allocateMappedMemory(
numBytes, nullptr, llvm::sys::Memory::MF_WRITE, ec)};
if (ec)
return nullptr;

instance->writableView = MutableArrayRef<std::byte>(
(std::byte *)instance->mmapBlock.base(), numBytes);

return instance;
}

llvm::sys::OwningMemoryBlock mmapBlock;
MutableArrayRef<std::byte> writableView;
};
} // namespace

std::shared_ptr<ArrayRef<std::byte>>
mlir::writeBytecode(Operation *op, const BytecodeWriterConfig &config) {
BytecodeWriter writer(op, config);
EncodingEmitter emitter;
if (succeeded(writer.writeInto(op, emitter))) {
// Allocate a new memory block for the emitter to write into.
auto block = MemoryMappedBlock::createMemoryMappedBlock(emitter.size());
if (!block)
return nullptr;

// Wrap the block in an output stream.
llvm::fixed_buffer_ostream stream(block->writableView);
emitter.writeTo(stream);

// Write protect the block.
if (llvm::sys::Memory::protectMappedMemory(
block->mmapBlock.getMemoryBlock(), llvm::sys::Memory::MF_READ))
return nullptr;

return std::shared_ptr<ArrayRef<std::byte>>(block, &block->writableView);
}

return nullptr;
}
22 changes: 22 additions & 0 deletions mlir/unittests/Bytecode/BytecodeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <cstring>

using namespace llvm;
using namespace mlir;
Expand Down Expand Up @@ -88,6 +90,26 @@ TEST(Bytecode, MultiModuleWithResource) {
checkResourceAttribute(*roundTripModule);
}

TEST(Bytecode, WriteEquivalence) {
MLIRContext context;
Builder builder(&context);
ParserConfig parseConfig(&context);
OwningOpRef<Operation *> module =
parseSourceString<Operation *>(irWithResources, parseConfig);
ASSERT_TRUE(module);

// Write the module to bytecode
std::string buffer;
llvm::raw_string_ostream ostream(buffer);
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));

// Write the module to bytecode using the mmap API.
auto writeBuffer = writeBytecode(module.get());
ASSERT_TRUE(writeBuffer);
ASSERT_EQ(writeBuffer->size(), buffer.size());
ASSERT_EQ(memcmp(buffer.data(), writeBuffer->data(), writeBuffer->size()), 0);
}

namespace {
/// A custom operation for the purpose of showcasing how discardable attributes
/// are handled in absence of properties.
Expand Down