- 
                Notifications
    You must be signed in to change notification settings 
- Fork 706
Deserialize with named data store #14743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -576,7 +576,7 @@ def serialize_pte_binary( | |
| return pte_data | ||
|  | ||
|  | ||
| def _restore_segments(program: Program, segment_data: bytes) -> Program: | ||
| def _restore_segments(program: Program, segment_data: bytes) -> Program: # noqa: C901 | ||
| """Moves segments from `segment_data` into `program`. | ||
|  | ||
| This should recreate the original Program that the segments were extracted | ||
|  | @@ -641,6 +641,48 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: | |
| program.constant_segment.segment_index = 0 | ||
| program.constant_segment.offsets = [] | ||
|  | ||
| # Reconstruct named data blobs from segment data when present. | ||
| if program.named_data: | ||
| segment_to_buffer_index: Dict[int, int] = {} | ||
| named_buffers: List[BufferEntry] = [] | ||
| key_to_buffer_index: Dict[str, int] = {} | ||
|  | ||
| for entry in program.named_data: | ||
| segment_index = entry.segment_index | ||
| if segment_index >= len(segments): | ||
| raise ValueError( | ||
| "Named data segment index " | ||
| f"{segment_index} >= num segments {len(segments)}" | ||
| ) | ||
|  | ||
| buffer_index = segment_to_buffer_index.get(segment_index) | ||
| if buffer_index is None: | ||
| buffer_index = len(named_buffers) | ||
| segment_to_buffer_index[segment_index] = buffer_index | ||
| named_buffers.append( | ||
| BufferEntry(buffer=segments[segment_index], alignment=1) | ||
| ) | ||
|  | ||
| key_to_buffer_index[entry.key] = buffer_index | ||
|  | ||
| named_data_store = NamedDataStoreOutput( | ||
| buffers=named_buffers, | ||
| pte_data=key_to_buffer_index, | ||
| external_data={}, | ||
| ) | ||
| # Keep a convenient mapping from key to raw bytes for callers that only | ||
| # need to read the blobs. | ||
| setattr( # noqa: B010 | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think we don't need this if we have the named_data_store in line 683. Not sure about attaching  | ||
| program, | ||
| "named_data_blobs", | ||
| { | ||
| key: named_data_store.buffers[idx].buffer | ||
| for key, idx in named_data_store.pte_data.items() | ||
| }, | ||
| ) | ||
| setattr(program, "named_data_store", named_data_store) # noqa: B010 | ||
| program.named_data = [] | ||
|  | ||
| # Clear out the segments list since the original Program didn't have one. | ||
| program.segments = [] | ||
| return program | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -893,6 +893,23 @@ def test_constant_delegate_and_named_data_segments(self) -> None: | |
| self.assertEqual(program2.execution_plan, program.execution_plan) | ||
| # Number of constant tensors should be the same. | ||
| self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) | ||
| # Named data should be restored in a named data store and removed from the Program. | ||
| self.assertEqual(program2.named_data, []) | ||
| self.assertTrue(hasattr(program2, "named_data_store")) | ||
| named_store = program2.named_data_store | ||
| self.assertEqual(named_store.pte_data, pte_named_data) | ||
| # Buffers in the restored store should match the original serialized blobs. | ||
| restored = { | ||
| key: named_store.buffers[idx].buffer | ||
| for key, idx in named_store.pte_data.items() | ||
| } | ||
| original = { | ||
| key: named_data_buffers[buf_idx].buffer | ||
| for key, buf_idx in pte_named_data.items() | ||
| } | ||
| self.assertEqual(restored, original) | ||
| self.assertTrue(hasattr(program2, "named_data_blobs")) | ||
| self.assertEqual(program2.named_data_blobs, restored) | ||
|  | ||
| def test_named_data_segments(self) -> None: | ||
| # Set segment alignment to 12 to test the padding. | ||
|  | @@ -997,6 +1014,21 @@ def test_named_data_segments(self) -> None: | |
| buffers[2].buffer, | ||
| ) | ||
|  | ||
| program2 = deserialize_pte_binary(pte_data) | ||
| self.assertEqual(program2.named_data, []) | ||
| self.assertTrue(hasattr(program2, "named_data_store")) | ||
| store = program2.named_data_store | ||
| self.assertEqual(store.pte_data, pte_named_data) | ||
| restored_named_data = { | ||
| key: store.buffers[idx].buffer for key, idx in store.pte_data.items() | ||
| } | ||
| self.assertEqual( | ||
| restored_named_data, | ||
| {key: buffers[idx].buffer for key, idx in pte_named_data.items()}, | ||
| ) | ||
| self.assertTrue(hasattr(program2, "named_data_blobs")) | ||
| self.assertEqual(program2.named_data_blobs, restored_named_data) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test the roundtrip --> can re-serialize the program with program2.named_data_blobs? | ||
|  | ||
|  | ||
| # Common data for extended header tests. The two example values should produce | ||
| # the example data. | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change?