22
22
#include " llvm/ADT/StringExtras.h"
23
23
#include " llvm/ADT/StringRef.h"
24
24
#include " llvm/Support/Endian.h"
25
+ #include " llvm/Support/Format.h"
26
+ #include " llvm/Support/LogicalResult.h"
25
27
#include " llvm/Support/MemoryBufferRef.h"
26
28
#include " llvm/Support/SourceMgr.h"
27
29
28
30
#include < cstddef>
31
+ #include < cstdint>
29
32
#include < list>
30
33
#include < memory>
31
34
#include < numeric>
@@ -111,6 +114,9 @@ class EncodingReader {
111
114
};
112
115
113
116
// Shift the reader position to the next alignment boundary.
117
+ // Note: this assumes the pointer alignment matches the alignment of the
118
+ // data from the start of the buffer. In other words, this code is only
119
+ // valid if `dataIt` is offsetting into an already aligned buffer.
114
120
while (isUnaligned (dataIt)) {
115
121
uint8_t padding;
116
122
if (failed (parseByte (padding)))
@@ -258,9 +264,13 @@ class EncodingReader {
258
264
return success ();
259
265
}
260
266
267
+ /// Validate that the alignment requested in the section is valid.
268
+ using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
269
+
261
270
/// Parse a section header, placing the kind of section in `sectionID` and the
262
271
/// contents of the section in `sectionData`.
263
272
LogicalResult parseSection (bytecode::Section::ID §ionID,
273
+ ValidateAlignmentFn alignmentValidator,
264
274
ArrayRef<uint8_t > §ionData) {
265
275
uint8_t sectionIDAndHasAlignment;
266
276
uint64_t length;
@@ -281,8 +291,22 @@ class EncodingReader {
281
291
282
292
// Process the section alignment if present.
283
293
if (hasAlignment) {
294
+ // Read the requested alignment from the bytecode parser.
284
295
uint64_t alignment;
285
- if (failed (parseVarInt (alignment)) || failed (alignTo (alignment)))
296
+ if (failed (parseVarInt (alignment)))
297
+ return failure ();
298
+
299
+ // Check that the requested alignment is less than or equal to the
300
+ // alignment of the root buffer. If it is not, we cannot safely guarantee
301
+ // that the specified alignment is globally correct.
302
+ //
303
+ // E.g. if the buffer is 8k aligned and the section is 16k aligned,
304
+ // we could end up at an offset of 24k, which is not globally 16k aligned.
305
+ if (failed (alignmentValidator (alignment)))
306
+ return emitError (" failed to align section ID: " , unsigned (sectionID));
307
+
308
+ // Align the buffer.
309
+ if (failed (alignTo (alignment)))
286
310
return failure ();
287
311
}
288
312
@@ -1396,6 +1420,29 @@ class mlir::BytecodeReader::Impl {
1396
1420
return success ();
1397
1421
}
1398
1422
1423
+ LogicalResult checkSectionAlignment (
1424
+ unsigned alignment,
1425
+ function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
1426
+ // Check that the bytecode buffer meets the requested section alignment.
1427
+ //
1428
+ // If it does not, the virtual address of the item in the section will
1429
+ // not be aligned to the requested alignment.
1430
+ //
1431
+ // The typical case where this is necessary is the resource blob
1432
+ // optimization in `parseAsBlob` where we reference the weights from the
1433
+ // provided buffer instead of copying them to a new allocation.
1434
+ const bool isGloballyAligned =
1435
+ ((uintptr_t )buffer.getBufferStart () & (alignment - 1 )) == 0 ;
1436
+
1437
+ if (!isGloballyAligned)
1438
+ return emitError (" expected section alignment " )
1439
+ << alignment << " but bytecode buffer 0x"
1440
+ << Twine::utohexstr ((uint64_t )buffer.getBufferStart ())
1441
+ << " is not aligned" ;
1442
+
1443
+ return success ();
1444
+ };
1445
+
1399
1446
/// Return the context for this config.
1400
1447
MLIRContext *getContext () const { return config.getContext (); }
1401
1448
@@ -1506,7 +1553,7 @@ class mlir::BytecodeReader::Impl {
1506
1553
UseListOrderStorage (bool isIndexPairEncoding,
1507
1554
SmallVector<unsigned , 4 > &&indices)
1508
1555
: indices(std::move(indices)),
1509
- isIndexPairEncoding (isIndexPairEncoding){};
1556
+ isIndexPairEncoding (isIndexPairEncoding) {};
1510
1557
// / The vector containing the information required to reorder the
1511
1558
// / use-list of a value.
1512
1559
SmallVector<unsigned , 4 > indices;
@@ -1651,14 +1698,20 @@ LogicalResult BytecodeReader::Impl::read(
1651
1698
return failure ();
1652
1699
});
1653
1700
1701
+ const auto checkSectionAlignment = [&](unsigned alignment) {
1702
+ return this ->checkSectionAlignment (
1703
+ alignment, [&](const auto &msg) { return reader.emitError (msg); });
1704
+ };
1705
+
1654
1706
// Parse the raw data for each of the top-level sections of the bytecode.
1655
1707
std::optional<ArrayRef<uint8_t >>
1656
1708
sectionDatas[bytecode::Section::kNumSections ];
1657
1709
while (!reader.empty ()) {
1658
1710
// Read the next section from the bytecode.
1659
1711
bytecode::Section::ID sectionID;
1660
1712
ArrayRef<uint8_t > sectionData;
1661
- if (failed (reader.parseSection (sectionID, sectionData)))
1713
+ if (failed (
1714
+ reader.parseSection (sectionID, checkSectionAlignment, sectionData)))
1662
1715
return failure ();
1663
1716
1664
1717
// Check for duplicate sections, we only expect one instance of each.
@@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1778
1831
return failure ();
1779
1832
dialects.resize (numDialects);
1780
1833
1834
+ const auto checkSectionAlignment = [&](unsigned alignment) {
1835
+ return this ->checkSectionAlignment (alignment, [&](const auto &msg) {
1836
+ return sectionReader.emitError (msg);
1837
+ });
1838
+ };
1839
+
1781
1840
// Parse each of the dialects.
1782
1841
for (uint64_t i = 0 ; i < numDialects; ++i) {
1783
1842
dialects[i] = std::make_unique<BytecodeDialect>();
@@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1800
1859
return failure ();
1801
1860
if (versionAvailable) {
1802
1861
bytecode::Section::ID sectionID;
1803
- if (failed (sectionReader.parseSection (sectionID,
1862
+ if (failed (sectionReader.parseSection (sectionID, checkSectionAlignment,
1804
1863
dialects[i]->versionBuffer )))
1805
1864
return failure ();
1806
1865
if (sectionID != bytecode::Section::kDialectVersions ) {
@@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2121
2180
LogicalResult
2122
2181
BytecodeReader::Impl::parseRegions (std::vector<RegionReadState> ®ionStack,
2123
2182
RegionReadState &readState) {
2183
+ const auto checkSectionAlignment = [&](unsigned alignment) {
2184
+ return this ->checkSectionAlignment (
2185
+ alignment, [&](const auto &msg) { return emitError (fileLoc, msg); });
2186
+ };
2187
+
2124
2188
// Process regions, blocks, and operations until the end or if a nested
2125
2189
// region is encountered. In this case we push a new state in regionStack and
2126
2190
// return, the processing of the current region will resume afterward.
@@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
2161
2225
if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
2162
2226
bytecode::Section::ID sectionID;
2163
2227
ArrayRef<uint8_t > sectionData;
2164
- if (failed (reader.parseSection (sectionID, sectionData)))
2228
+ if (failed (reader.parseSection (sectionID, checkSectionAlignment,
2229
+ sectionData)))
2165
2230
return failure ();
2166
2231
if (sectionID != bytecode::Section::kIR )
2167
2232
return emitError (fileLoc, " expected IR section for region" );
0 commit comments