Skip to content

Commit 0e3d6d3

Browse files
pytorchbotlucylq
andauthored
Add segment_data_size to extended header (#14226)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #14091 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/107/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/107/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/107/orig @diff-train-skip-merge --------- Co-authored-by: lucylq <[email protected]>
1 parent 8358516 commit 0e3d6d3

File tree

8 files changed

+177
-18
lines changed

8 files changed

+177
-18
lines changed

docs/source/pte-file-format.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ Optional extended header:
7171
| byte offset zero above. I.e., it includes these headers.
7272
| [24..31] uint64_t offset (from byte offset zero above) to the start of the
7373
| first segment, or zero if there are no segments.
74-
| [31..?] Any zero-padding necessary to preserve the alignment of the data
74+
| [32..39] uint64_t size of the segment data, ie. the size from the segment_base_offset
75+
| to the end of the segments. Note, the last segment should not have any
76+
| trailing padding.
77+
| [40..?] Any zero-padding necessary to preserve the alignment of the data
7578
| that follows.
7679
End of optional extended header.
7780
```
@@ -81,13 +84,16 @@ Example:
8184
Offset to flatbuffer root (0x38)
8285
| File magic ("ET??")
8386
| | Extended header magic ("eh??")
84-
| | | Extended header size (0x18)
87+
| | | Extended header size (0x20)
8588
vvvvvvvvvvv vvvvvvvvvvv vvvvvvvvvvv vvvvvvvvvvv
86-
0x0000 38 00 00 00 45 54 3F 3F 65 68 3F 3F 18 00 00 00
89+
0x0000 38 00 00 00 45 54 3F 3F 65 68 3F 3F 20 00 00 00
8790
0x0010 F0 02 00 00 00 00 00 00 00 10 00 00 00 00 00 00
91+
0x0020 20
8892
^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^
8993
| Offset to segments (0x1000)
9094
Size of program flatbuffer data (0x2f0)
95+
|
96+
Segment data size (0x20)
9197
```
9298

9399
## Program data

exir/_serialize/_program.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ class _ExtendedHeader:
146146
+ 8
147147
# Segment base offset
148148
+ 8
149+
# Segment data size
150+
+ 8
149151
)
150152

151153
# Instance attributes. @dataclass will turn these into ctor args.
@@ -155,6 +157,9 @@ class _ExtendedHeader:
155157
# Offset to the start of the first segment, or zero if there
156158
# are no segments.
157159
segment_base_offset: int
160+
# Size of the segment data, in bytes, or zero if there are no segments, or
161+
# if the this field isn't populated in the PTE file.
162+
segment_data_size: int
158163

159164
# The magic bytes read from or to be written to the binary header.
160165
magic: bytes = EXPECTED_MAGIC
@@ -189,6 +194,7 @@ def from_bytes(data: bytes) -> "_ExtendedHeader":
189194
segment_base_offset=int.from_bytes(
190195
data[16:24], byteorder=_HEADER_BYTEORDER
191196
),
197+
segment_data_size=int.from_bytes(data[24:32], byteorder=_HEADER_BYTEORDER),
192198
)
193199

194200
def is_valid(self) -> bool:
@@ -220,6 +226,9 @@ def to_bytes(self) -> bytes:
220226
# uint64_t: Offset to the start of the first segment, or zero if
221227
# there are no segments.
222228
+ self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
229+
# uint64_t: size of the segment data, or zero if there are no
230+
# segments.
231+
+ self.segment_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
223232
)
224233
return data
225234

@@ -512,7 +521,9 @@ def serialize_pte_binary(
512521

513522
# Construct and pad the extended header.
514523
header_data: bytes = _ExtendedHeader(
515-
program_size=program_size, segment_base_offset=segment_base_offset
524+
program_size=program_size,
525+
segment_base_offset=segment_base_offset,
526+
segment_data_size=len(segments_data),
516527
).to_bytes()
517528
header_data = pad_to(header_data, padded_header_length)
518529

exir/_serialize/test/test_program.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def constant_segment_with_tensor_alignment(
191191
# the end of the file.
192192
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
193193
self.assertLess(eh.segment_base_offset, len(pte_data))
194+
# Segment data_size should be non-zero since there are segments.
195+
self.assertGreater(eh.segment_data_size, 0)
194196

195197
# Peek inside the actual flatbuffer data to see the segments.
196198
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
@@ -232,6 +234,8 @@ def constant_segment_with_tensor_alignment(
232234
# Check segment data.
233235
offsets = subsegment_offsets.offsets
234236
segment_data: bytes = pte_data[eh.segment_base_offset :]
237+
# Check segment data size.
238+
self.assertEqual(len(segment_data), eh.segment_data_size)
235239

236240
# tensor[1]: padding.
237241
self.assertEqual(
@@ -514,6 +518,8 @@ def test_round_trip_with_segments(self) -> None:
514518
# the end of the file.
515519
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
516520
self.assertLess(eh.segment_base_offset, len(pte_data))
521+
# Segment data size should be non-zero since there are segments.
522+
self.assertGreater(eh.segment_data_size, 0)
517523

518524
# Peek inside the actual flatbuffer data to see the segments. Note that
519525
# this also implicity tests the case where we try parsing the entire
@@ -566,6 +572,8 @@ def test_round_trip_with_segments(self) -> None:
566572
# Now that we've shown that the base offset is correct, slice off the
567573
# front so that all segment offsets are relative to zero.
568574
segment_data: bytes = pte_data[segment_base_offset:]
575+
# Check segment data size.
576+
self.assertEqual(len(segment_data), eh.segment_data_size)
569577

570578
# End of the first segment. It's much smaller than the alignment,
571579
# so we know that it's followed by zeros.
@@ -729,6 +737,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
729737
# the end of the file.
730738
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
731739
self.assertLess(eh.segment_base_offset, len(pte_data))
740+
# Segment data size should be non-zero since there are segments.
741+
self.assertGreater(eh.segment_data_size, 0)
732742

733743
# Peek inside the actual flatbuffer data to see the segments.
734744
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
@@ -811,6 +821,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
811821
# Now that we've shown that the base offset is correct, slice off the
812822
# front so that all segment offsets are relative to zero.
813823
segment_data: bytes = pte_data[segment_base_offset:]
824+
# Check segment data size.
825+
self.assertEqual(len(segment_data), eh.segment_data_size)
814826

815827
# Check segment[0] for constants.
816828
offsets = subsegment_offsets.offsets
@@ -925,6 +937,8 @@ def test_named_data_segments(self) -> None:
925937
# the end of the file.
926938
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
927939
self.assertLess(eh.segment_base_offset, len(pte_data))
940+
# Segment data size should be non-zero since there are segments.
941+
self.assertGreater(eh.segment_data_size, 0)
928942

929943
# Peek inside the actual flatbuffer data to see the named data segments.
930944
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
@@ -958,6 +972,9 @@ def test_named_data_segments(self) -> None:
958972

959973
# Check the pte data for buffer values.
960974
segment_data: bytes = pte_data[eh.segment_base_offset :]
975+
# Check segment data size.
976+
self.assertEqual(len(segment_data), eh.segment_data_size)
977+
961978
self.assertEqual(
962979
segment_data[
963980
segment_table[0].offset : segment_table[0].offset
@@ -985,18 +1002,21 @@ def test_named_data_segments(self) -> None:
9851002
# the example data.
9861003
EXAMPLE_PROGRAM_SIZE: int = 0x1122112233443344
9871004
EXAMPLE_SEGMENT_BASE_OFFSET: int = 0x5566556677887788
1005+
EXAMPLE_SEGMENT_DATA_SIZE: int = 0x5544554433223322
9881006
# This data is intentionally fragile. If the header layout or magic changes,
9891007
# this test must change too. The layout of the header is a contract, not an
9901008
# implementation detail.
9911009
EXAMPLE_HEADER_DATA: bytes = (
9921010
# Magic bytes
9931011
b"eh00"
9941012
# uint32_t header size (little endian)
995-
+ b"\x18\x00\x00\x00"
1013+
+ b"\x20\x00\x00\x00"
9961014
# uint64_t program size
9971015
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
9981016
# uint64_t segment base offset
9991017
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
1018+
# uint64_t segment data size
1019+
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
10001020
)
10011021

10021022

@@ -1005,6 +1025,7 @@ def test_to_bytes(self) -> None:
10051025
eh = _ExtendedHeader(
10061026
program_size=EXAMPLE_PROGRAM_SIZE,
10071027
segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET,
1028+
segment_data_size=EXAMPLE_SEGMENT_DATA_SIZE,
10081029
)
10091030
self.assertTrue(eh.is_valid())
10101031
self.assertEqual(eh.to_bytes(), EXAMPLE_HEADER_DATA)
@@ -1013,6 +1034,7 @@ def test_to_bytes_with_non_defaults(self) -> None:
10131034
eh = _ExtendedHeader(
10141035
program_size=EXAMPLE_PROGRAM_SIZE,
10151036
segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET,
1037+
segment_data_size=EXAMPLE_SEGMENT_DATA_SIZE,
10161038
# Override the default magic and length, to demonstrate that this
10171039
# does not affect the serialized header.
10181040
magic=b"ABCD",
@@ -1036,6 +1058,7 @@ def test_from_bytes_valid(self) -> None:
10361058
self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH)
10371059
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
10381060
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
1061+
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)
10391062

10401063
def test_from_bytes_with_more_data_than_necessary(self) -> None:
10411064
# Pass in more data than necessary to parse the header.
@@ -1049,6 +1072,7 @@ def test_from_bytes_with_more_data_than_necessary(self) -> None:
10491072
self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH)
10501073
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
10511074
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
1075+
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)
10521076

10531077
def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
10541078
# Simulate a backwards-compatibility situation. Parse a header
@@ -1059,11 +1083,13 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
10591083
# Magic bytes
10601084
b"eh00"
10611085
# uint32_t header size (little endian)
1062-
+ b"\x1c\x00\x00\x00" # Longer than expected
1086+
+ b"\x21\x00\x00\x00" # Longer than expected
10631087
# uint64_t program size
10641088
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
10651089
# uint64_t segment base offset
10661090
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
1091+
# uint64_t segment data size
1092+
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
10671093
# uint32_t new field (ignored)
10681094
+ b"\xff\xee\xff\xee"
10691095
)
@@ -1075,9 +1101,10 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
10751101
self.assertTrue(eh.is_valid())
10761102

10771103
self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC)
1078-
self.assertEqual(eh.length, 28)
1104+
self.assertEqual(eh.length, 33)
10791105
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
10801106
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
1107+
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)
10811108

10821109
def test_from_bytes_not_enough_data_fails(self) -> None:
10831110
# Parsing a truncated prefix should fail.
@@ -1090,11 +1117,13 @@ def test_from_bytes_invalid_magic(self) -> None:
10901117
# Magic bytes
10911118
b"ABCD" # Invalid
10921119
# uint32_t header size (little endian)
1093-
+ b"\x18\x00\x00\x00"
1120+
+ b"\x20\x00\x00\x00"
10941121
# uint64_t program size
10951122
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
10961123
# uint64_t segment base offset
10971124
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
1125+
# uint64_t segment data size
1126+
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
10981127
)
10991128

11001129
# Parse the serialized extended header.
@@ -1109,6 +1138,7 @@ def test_from_bytes_invalid_magic(self) -> None:
11091138
self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH)
11101139
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
11111140
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
1141+
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)
11121142

11131143
def test_from_bytes_invalid_length(self) -> None:
11141144
# An invalid serialized header
@@ -1121,6 +1151,8 @@ def test_from_bytes_invalid_length(self) -> None:
11211151
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
11221152
# uint64_t segment base offset
11231153
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
1154+
# uint64_t segment data size
1155+
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
11241156
)
11251157

11261158
# Parse the serialized extended header.
@@ -1135,3 +1167,4 @@ def test_from_bytes_invalid_length(self) -> None:
11351167
self.assertEqual(eh.length, 16)
11361168
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
11371169
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
1170+
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)

runtime/executor/program.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
6767
// See if the program size is in the header.
6868
size_t program_size = 0;
6969
size_t segment_base_offset = 0;
70+
size_t segment_data_size = 0;
7071
{
7172
EXECUTORCH_SCOPE_PROF("Program::check_header");
7273
Result<FreeableBuffer> header = loader->load(
@@ -82,6 +83,24 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
8283
// The header has the program size.
8384
program_size = eh->program_size;
8485
segment_base_offset = eh->segment_base_offset;
86+
segment_data_size = eh->segment_data_size;
87+
88+
// segment_data_size was added in ET 1.0 release. For BC, only check the
89+
// expected file size when there are no segments or when segment_data_size
90+
// is positive (0-value may indicate no segments)
91+
if ((segment_data_size == 0 && segment_base_offset == 0) ||
92+
segment_data_size > 0) {
93+
size_t expected = segment_base_offset == 0
94+
? program_size
95+
: segment_base_offset + segment_data_size;
96+
size_t actual = loader->size().get();
97+
ET_CHECK_OR_RETURN_ERROR(
98+
expected <= actual,
99+
InvalidProgram,
100+
"File size is too small. Expected file size from extended header is %zu, actual file size from data loader is %zu",
101+
expected,
102+
actual);
103+
}
85104
} else if (eh.error() == Error::NotFound) {
86105
// No header; the program consumes the whole file, and there are no
87106
// segments.

runtime/executor/test/program_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,22 @@ TEST_F(ProgramTest, LoadFromMutableSegment) {
574574
&program.get(), 500, 1, 1, buffer);
575575
EXPECT_NE(err, Error::Ok);
576576
}
577+
578+
TEST_F(ProgramTest, LoadAndCheckPTESize) {
579+
// Load the serialized ModuleAddMul data, with constants in the segment.
580+
const char* linear_path = std::getenv("ET_MODULE_ADD_MUL_PATH");
581+
Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
582+
ASSERT_EQ(linear_loader.error(), Error::Ok);
583+
Result<Program> program = Program::load(&linear_loader.get());
584+
ASSERT_EQ(program.error(), Error::Ok);
585+
586+
// Create a truncated file.
587+
Result<FreeableBuffer> truncated_file = linear_loader->load(
588+
0, 200, DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
589+
ASSERT_EQ(truncated_file.error(), Error::Ok);
590+
591+
BufferDataLoader truncated_loader =
592+
BufferDataLoader(truncated_file->data(), 200);
593+
Result<Program> truncated_program = Program::load(&truncated_loader);
594+
ASSERT_EQ(truncated_program.error(), Error::InvalidProgram);
595+
}

schema/extended_header.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ static constexpr size_t kHeaderSegmentBaseOffsetOffset =
4141
static constexpr size_t kMinimumHeaderLength =
4242
kHeaderSegmentBaseOffsetOffset + sizeof(uint64_t);
4343

44+
/// The expected location of the segment_data_size field relative to the
45+
/// beginning of the header.
46+
static constexpr size_t kHeaderSegmentDataSizeOffset =
47+
kHeaderSegmentBaseOffsetOffset + sizeof(uint64_t);
48+
49+
/// The expected length of the header, including the segment_data_size field.
50+
static constexpr size_t kHeaderLengthWithSegmentDataSize =
51+
kHeaderSegmentDataSizeOffset + sizeof(uint64_t);
52+
4453
/// Interprets the 4 bytes at `data` as a little-endian uint32_t.
4554
uint32_t GetUInt32LE(const uint8_t* data) {
4655
return (uint32_t)data[0] | ((uint32_t)data[1] << 8) |
@@ -83,11 +92,17 @@ uint64_t GetUInt64LE(const uint8_t* data) {
8392
return Error::InvalidProgram;
8493
}
8594

95+
uint64_t segment_data_size = 0;
96+
if (header_length >= kHeaderLengthWithSegmentDataSize) {
97+
segment_data_size = GetUInt64LE(header + kHeaderSegmentDataSizeOffset);
98+
}
99+
86100
// The header is present and apparently valid.
87101
return ExtendedHeader{
88102
/*program_size=*/GetUInt64LE(header + kHeaderProgramSizeOffset),
89103
/*segment_base_offset=*/
90104
GetUInt64LE(header + kHeaderSegmentBaseOffsetOffset),
105+
/*segment_data_size=*/segment_data_size,
91106
};
92107
}
93108

schema/extended_header.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ struct ExtendedHeader {
7070
* is present.
7171
*/
7272
uint64_t segment_base_offset;
73+
74+
/**
75+
* The size of all the segment data, in bytes. Zero if:
76+
* - no segment is present
77+
* - the segment_data_size field doesn't exist in the header - the case for
78+
* older PTE files.
79+
*/
80+
uint64_t segment_data_size;
7381
};
7482

7583
} // namespace runtime

0 commit comments

Comments
 (0)