Skip to content

Commit 8106c81

Browse files
authored
[MLIR][Bytecode] Enforce alignment requirements (#157004)
Adds a check that the bytecode buffer is aligned to any section alignment requirements. Without this check, if the source buffer is not sufficiently aligned, we may return early when aligning the data pointer. In that case, we may end up trying to read successive sections from an incorrect offset, giving the appearance of invalid bytecode. This requirement is documented in the bytecode unit tests, but is not otherwise documented in the code or bytecode reference.
1 parent ffbd616 commit 8106c81

File tree

3 files changed

+124
-6
lines changed

3 files changed

+124
-6
lines changed

mlir/docs/BytecodeFormat.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ lazy-loading, and more. Each section contains a Section ID, whose high bit
125125
indicates if the section has alignment requirements, a length (which allows for
126126
skipping over the section), and an optional alignment. When an alignment is
127127
present, a variable number of padding bytes (0xCB) may appear before the section
128-
data. The alignment of a section must be a power of 2.
128+
data. The alignment of a section must be a power of 2. The input bytecode buffer must satisfy the same alignment requirements as those of every section.
129129

130130
## MLIR Encoding
131131

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@
2222
#include "llvm/ADT/StringExtras.h"
2323
#include "llvm/ADT/StringRef.h"
2424
#include "llvm/Support/Endian.h"
25+
#include "llvm/Support/Format.h"
26+
#include "llvm/Support/LogicalResult.h"
2527
#include "llvm/Support/MemoryBufferRef.h"
2628
#include "llvm/Support/SourceMgr.h"
2729

2830
#include <cstddef>
31+
#include <cstdint>
2932
#include <list>
3033
#include <memory>
3134
#include <numeric>
@@ -111,6 +114,9 @@ class EncodingReader {
111114
};
112115

113116
// Shift the reader position to the next alignment boundary.
117+
// Note: this assumes the pointer alignment matches the alignment of the
118+
// data from the start of the buffer. In other words, this code is only
119+
// valid if `dataIt` is offsetting into an already aligned buffer.
114120
while (isUnaligned(dataIt)) {
115121
uint8_t padding;
116122
if (failed(parseByte(padding)))
@@ -258,9 +264,13 @@ class EncodingReader {
258264
return success();
259265
}
260266

267+
/// Validate that the alignment requested in the section is valid.
268+
using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
269+
261270
/// Parse a section header, placing the kind of section in `sectionID` and the
262271
/// contents of the section in `sectionData`.
263272
LogicalResult parseSection(bytecode::Section::ID &sectionID,
273+
ValidateAlignmentFn alignmentValidator,
264274
ArrayRef<uint8_t> &sectionData) {
265275
uint8_t sectionIDAndHasAlignment;
266276
uint64_t length;
@@ -281,8 +291,22 @@ class EncodingReader {
281291

282292
// Process the section alignment if present.
283293
if (hasAlignment) {
294+
// Read the requested alignment from the bytecode parser.
284295
uint64_t alignment;
285-
if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
296+
if (failed(parseVarInt(alignment)))
297+
return failure();
298+
299+
// Check that the requested alignment is less than or equal to the
300+
// alignment of the root buffer. If it is not, we cannot safely guarantee
301+
// that the specified alignment is globally correct.
302+
//
303+
// E.g. if the buffer is 8k aligned and the section is 16k aligned,
304+
// we could end up at an offset of 24k, which is not globally 16k aligned.
305+
if (failed(alignmentValidator(alignment)))
306+
return emitError("failed to align section ID: ", unsigned(sectionID));
307+
308+
// Align the buffer.
309+
if (failed(alignTo(alignment)))
286310
return failure();
287311
}
288312

@@ -1396,6 +1420,29 @@ class mlir::BytecodeReader::Impl {
13961420
return success();
13971421
}
13981422

1423+
LogicalResult checkSectionAlignment(
1424+
unsigned alignment,
1425+
function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
1426+
// Check that the bytecode buffer meets the requested section alignment.
1427+
//
1428+
// If it does not, the virtual address of the item in the section will
1429+
// not be aligned to the requested alignment.
1430+
//
1431+
// The typical case where this is necessary is the resource blob
1432+
// optimization in `parseAsBlob` where we reference the weights from the
1433+
// provided buffer instead of copying them to a new allocation.
1434+
const bool isGloballyAligned =
1435+
((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1436+
1437+
if (!isGloballyAligned)
1438+
return emitError("expected section alignment ")
1439+
<< alignment << " but bytecode buffer 0x"
1440+
<< Twine::utohexstr((uint64_t)buffer.getBufferStart())
1441+
<< " is not aligned";
1442+
1443+
return success();
1444+
};
1445+
13991446
/// Return the context for this config.
14001447
MLIRContext *getContext() const { return config.getContext(); }
14011448

@@ -1506,7 +1553,7 @@ class mlir::BytecodeReader::Impl {
15061553
UseListOrderStorage(bool isIndexPairEncoding,
15071554
SmallVector<unsigned, 4> &&indices)
15081555
: indices(std::move(indices)),
1509-
isIndexPairEncoding(isIndexPairEncoding){};
1556+
isIndexPairEncoding(isIndexPairEncoding) {};
15101557
/// The vector containing the information required to reorder the
15111558
/// use-list of a value.
15121559
SmallVector<unsigned, 4> indices;
@@ -1651,14 +1698,20 @@ LogicalResult BytecodeReader::Impl::read(
16511698
return failure();
16521699
});
16531700

1701+
const auto checkSectionAlignment = [&](unsigned alignment) {
1702+
return this->checkSectionAlignment(
1703+
alignment, [&](const auto &msg) { return reader.emitError(msg); });
1704+
};
1705+
16541706
// Parse the raw data for each of the top-level sections of the bytecode.
16551707
std::optional<ArrayRef<uint8_t>>
16561708
sectionDatas[bytecode::Section::kNumSections];
16571709
while (!reader.empty()) {
16581710
// Read the next section from the bytecode.
16591711
bytecode::Section::ID sectionID;
16601712
ArrayRef<uint8_t> sectionData;
1661-
if (failed(reader.parseSection(sectionID, sectionData)))
1713+
if (failed(
1714+
reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
16621715
return failure();
16631716

16641717
// Check for duplicate sections, we only expect one instance of each.
@@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
17781831
return failure();
17791832
dialects.resize(numDialects);
17801833

1834+
const auto checkSectionAlignment = [&](unsigned alignment) {
1835+
return this->checkSectionAlignment(alignment, [&](const auto &msg) {
1836+
return sectionReader.emitError(msg);
1837+
});
1838+
};
1839+
17811840
// Parse each of the dialects.
17821841
for (uint64_t i = 0; i < numDialects; ++i) {
17831842
dialects[i] = std::make_unique<BytecodeDialect>();
@@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
18001859
return failure();
18011860
if (versionAvailable) {
18021861
bytecode::Section::ID sectionID;
1803-
if (failed(sectionReader.parseSection(sectionID,
1862+
if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
18041863
dialects[i]->versionBuffer)))
18051864
return failure();
18061865
if (sectionID != bytecode::Section::kDialectVersions) {
@@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
21212180
LogicalResult
21222181
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
21232182
RegionReadState &readState) {
2183+
const auto checkSectionAlignment = [&](unsigned alignment) {
2184+
return this->checkSectionAlignment(
2185+
alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
2186+
};
2187+
21242188
// Process regions, blocks, and operations until the end or if a nested
21252189
// region is encountered. In this case we push a new state in regionStack and
21262190
// return, the processing of the current region will resume afterward.
@@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
21612225
if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
21622226
bytecode::Section::ID sectionID;
21632227
ArrayRef<uint8_t> sectionData;
2164-
if (failed(reader.parseSection(sectionID, sectionData)))
2228+
if (failed(reader.parseSection(sectionID, checkSectionAlignment,
2229+
sectionData)))
21652230
return failure();
21662231
if (sectionID != bytecode::Section::kIR)
21672232
return emitError(fileLoc, "expected IR section for region");

mlir/unittests/Bytecode/BytecodeTest.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
#include "mlir/Bytecode/BytecodeWriter.h"
1111
#include "mlir/IR/AsmState.h"
1212
#include "mlir/IR/BuiltinAttributes.h"
13+
#include "mlir/IR/Diagnostics.h"
1314
#include "mlir/IR/OpImplementation.h"
1415
#include "mlir/IR/OwningOpRef.h"
1516
#include "mlir/Parser/Parser.h"
1617

1718
#include "llvm/ADT/StringRef.h"
19+
#include "llvm/Support/Alignment.h"
1820
#include "llvm/Support/Endian.h"
1921
#include "llvm/Support/MemoryBufferRef.h"
2022
#include "llvm/Support/raw_ostream.h"
@@ -117,6 +119,57 @@ TEST(Bytecode, MultiModuleWithResource) {
117119
checkResourceAttribute(*roundTripModule);
118120
}
119121

122+
TEST(Bytecode, AlignmentFailure) {
123+
MLIRContext context;
124+
Builder builder(&context);
125+
ParserConfig parseConfig(&context);
126+
OwningOpRef<Operation *> module =
127+
parseSourceString<Operation *>(irWithResources, parseConfig);
128+
ASSERT_TRUE(module);
129+
130+
// Write the module to bytecode.
131+
MockOstream ostream;
132+
EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) {
133+
ostream.buffer = std::make_unique<std::byte[]>(space);
134+
ostream.size = space;
135+
});
136+
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
137+
138+
// Create copy of buffer which is not aligned to requested resource alignment.
139+
std::string buffer((char *)ostream.buffer.get(),
140+
(char *)ostream.buffer.get() + ostream.size);
141+
size_t bufferSize = buffer.size();
142+
143+
// Increment into the buffer until we get to a power of 2 alignment that is
144+
// not 32 bit aligned.
145+
size_t pad = 0;
146+
while (true) {
147+
if (llvm::isAddrAligned(Align(2), &buffer[pad]) &&
148+
!llvm::isAddrAligned(Align(32), &buffer[pad]))
149+
break;
150+
151+
pad++;
152+
buffer.reserve(bufferSize + pad);
153+
}
154+
155+
buffer.insert(0, pad, ' ');
156+
StringRef alignedBuffer(buffer.data() + pad, bufferSize);
157+
158+
// Attach a diagnostic handler to get the error message.
159+
llvm::SmallVector<std::string> msg;
160+
ScopedDiagnosticHandler handler(
161+
&context, [&msg](Diagnostic &diag) { msg.push_back(diag.str()); });
162+
163+
// Parse it back
164+
OwningOpRef<Operation *> roundTripModule =
165+
parseSourceString<Operation *>(alignedBuffer, parseConfig);
166+
ASSERT_FALSE(roundTripModule);
167+
ASSERT_THAT(msg[0].data(), ::testing::StartsWith(
168+
"expected section alignment 32 but bytecode "
169+
"buffer"));
170+
ASSERT_STREQ(msg[1].data(), "failed to align section ID: 5");
171+
}
172+
120173
namespace {
121174
/// A custom operation for the purpose of showcasing how discardable attributes
122175
/// are handled in absence of properties.

0 commit comments

Comments
 (0)