Skip to content

Commit 914dace

Browse files
committed
[executorch][emitter] Emit FQNs
Emit FQNs for external tensors. In the emitter, store external tensors as:
```
// list of unique tensors
external_constant_buffer: List[bytes]
// map of {constant_tag: {fqn: index into external_constant_buffer}}
// constant_tag: may want to save multiple external constant files; group them together via the tag.
// {fqn: index}; there may be multiple fqns pointing to the same data buffer. This is for deduplication.
external_constant_map: Dict[str, Dict[str, int]]
```
Differential Revision: [D66523226](https://our.internmc.facebook.com/intern/diff/D66523226/)
ghstack-source-id: 256647498
Pull Request resolved: #7192
1 parent d985fdf commit 914dace

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

exir/emit/_emit_program.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ class EmitterOutput:
4747

4848
mutable_data: Optional[List[Buffer]]
4949

50+
# Constants are optionally stored in external files.
51+
# Aggregate unique external constants into one buffer.
52+
external_constant_buffer: List[bytes]
53+
# Each constant_tag groups a set of constants together.
54+
# {constant_tag: {fqn: index into external_constant_buffer}}
55+
external_constant_map: Optional[Dict[str, Dict[str, int]]]
56+
5057

5158
def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
5259
gm = exported_program.graph_module
@@ -199,4 +206,6 @@ def emit_program(
199206
if len(program_state.mutable_buffer) > 1
200207
else None
201208
),
209+
external_constant_buffer=program_state.external_constant_buffer,
210+
external_constant_map=program_state.external_constant_map,
202211
)

exir/emit/_emitter.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
DoubleList,
6363
EValue,
6464
ExecutionPlan,
65+
ExtraTensorInfo,
6566
FreeCall,
6667
Instruction,
6768
Int,
@@ -120,6 +121,14 @@ class _ProgramState:
120121
# and should be copied to Program.backend_delegate_data.
121122
backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
122123

124+
# Constants are optionally stored in external files.
125+
# Aggregate unique external constants into one buffer.
126+
external_constant_buffer: List[bytes] = field(default_factory=list)
127+
external_constant_hash: Dict[str, int] = field(default_factory=dict)
128+
# Each constant_tag groups a set of constants together.
129+
# {constant_tag: {fqn: index into external_constant_buffer}}
130+
external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict)
131+
123132

124133
@dataclass
125134
class _EmitterState:
@@ -328,7 +337,9 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
328337
ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
329338
)
330339

331-
def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
340+
def _tensor_spec_to_evalue(
341+
self, spec: TensorSpec, constant_tag: Optional[str] = None
342+
) -> EValue:
332343
"""Constructs an EValue from the given TensorSpec."""
333344

334345
allocation_info = None
@@ -389,6 +400,8 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
389400
buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
390401
hashed, -1
391402
)
403+
elif spec.location == DataLocation.EXTERNAL:
404+
buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
392405
else:
393406
buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)
394407

@@ -405,6 +418,23 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
405418
buffer_idx
406419
)
407420
self.program_state.mutable_buffer.append(buffer)
421+
422+
# Constant tensor, stored in external file.
423+
elif spec.location == DataLocation.EXTERNAL:
424+
assert (
425+
spec.extra_tensor_info is not None
426+
and spec.extra_tensor_info.fully_qualified_name is not None
427+
), "Fully qualified name is not set for external tensor"
428+
buffer_idx = len(self.program_state.external_constant_buffer)
429+
self.program_state.external_constant_hash[hashed] = buffer_idx
430+
self.program_state.external_constant_buffer.append(buffer_data)
431+
if constant_tag:
432+
if constant_tag not in self.program_state.external_constant_map:
433+
self.program_state.external_constant_map[constant_tag] = {}
434+
self.program_state.external_constant_map[constant_tag][
435+
spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
436+
] = buffer_idx
437+
# Constant tensor, stored in PTE.
408438
else:
409439
buffer_idx = len(self.program_state.constant_buffer)
410440
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
@@ -1539,11 +1569,24 @@ def placeholder(
15391569
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
15401570
"""
15411571
spec = self.node.meta["spec"]
1572+
constant_tag = self.node.meta.get("constant_tag", None)
15421573
is_user_input = True
15431574

15441575
if isinstance(target, str) and isinstance(spec, TensorSpec):
15451576
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
15461577

1578+
# If the placeholder has a constant_tag, it is external to the PTE file
1579+
# and requires a fqn and location=DataLocation.EXTERNAL
1580+
if constant_tag is not None:
1581+
assert (
1582+
fqn is not None
1583+
), "constant tagged tensors require a fully qualified name"
1584+
if spec.extra_tensor_info is None:
1585+
spec.extra_tensor_info = ExtraTensorInfo(fully_qualified_name=fqn)
1586+
else:
1587+
spec.extra_tensor_info.fully_qualified_name = fqn
1588+
spec.location = DataLocation.EXTERNAL
1589+
15471590
# From the fqn find the corresponding tensor
15481591
real_tensor = None
15491592
if fqn in self.exported_program.state_dict:
@@ -1581,7 +1624,7 @@ def placeholder(
15811624
spec.const = not (is_user_input or is_mutable_buffer)
15821625

15831626
evalue = (
1584-
self._tensor_spec_to_evalue(spec)
1627+
self._tensor_spec_to_evalue(spec, constant_tag)
15851628
if isinstance(spec, TensorSpec)
15861629
else self._constant_to_evalue(spec, None)
15871630
)

exir/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class ExtraTensorInfo:
4949
Check program.fbs for explanations of this enum.
5050
"""
5151

52-
mutable_data_segments_idx: Optional[int] = None
52+
mutable_data_segments_idx: Optional[int] = 0
5353
fully_qualified_name: Optional[str] = None
5454

5555

0 commit comments

Comments
 (0)