Skip to content

Commit 1744514

Browse files
pytorchbotlucylq
andauthored
Deserialize to named data store output (#15484)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #15469 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/122/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/122/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/122/orig Differential Revision: [D83510300](https://our.internmc.facebook.com/intern/diff/D83510300/) @diff-train-skip-merge Co-authored-by: lucylq <[email protected]>
1 parent a4b234f commit 1744514

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

extension/flat_tensor/serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ runtime.python_library(
2929
],
3030
visibility = [
3131
"//executorch/...",
32+
"@EXECUTORCH_CLIENTS",
3233
],
3334
deps = [
3435
":schema",

extension/flat_tensor/serialize/serialize.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from executorch.exir._serialize._cord import Cord
2121
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
2222
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
23+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
2324
from executorch.exir._serialize._program import _insert_flatbuffer_header
2425
from executorch.exir._serialize.data_serializer import (
2526
DataEntry,
@@ -389,6 +390,8 @@ def serialize(
389390
def deserialize(self, blob: Cord) -> DataPayload:
390391
"""
391392
Deserializes a flat_tensor blob into a list of tensor metadata and tensors.
393+
394+
Note: deserialization does not preserve alignment information.
392395
"""
393396

394397
data = bytes(blob)
@@ -436,3 +439,14 @@ def deserialize(self, blob: Cord) -> DataPayload:
436439
payload.named_data[named_data.key] = entry
437440

438441
return payload
442+
443+
def deserialize_to_named_data_store_output(
444+
self, blob: bytes, name: str
445+
) -> NamedDataStoreOutput:
446+
bytes = Cord(blob)
447+
data_payload = self.deserialize(bytes)
448+
return NamedDataStoreOutput(
449+
buffers=data_payload.buffers,
450+
pte_data={},
451+
external_data={name: data_payload.named_data},
452+
)

0 commit comments

Comments
 (0)