Skip to content

Commit 8c07f59

Browse files
committed
Update base for Update on "[executorch][passes] Add config and pass to tag constants for external file"
- Add config 'external_constants' to ExecutorchBackendConfig. - When set to True, run the 'external_constants_pass' - This tags all constants as external, and moves them into a separate buffer to be serialized outside of the PTE file. Note: users can write their own passes to tag weights to specific files / multiple files. TODO: write example pass and test for the case where we have two constant files. Differential Revision: [D66560903](https://our.internmc.facebook.com/intern/diff/D66560903/) [ghstack-poisoned]
2 parents ca7fedc + 538bfaf commit 8c07f59

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

exir/emit/_emitter.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
ScalarType,
7878
String,
7979
Tensor,
80+
TensorDataLocation,
8081
TensorList,
8182
TensorShapeDynamism,
8283
)
@@ -372,8 +373,8 @@ def _save_new_const_tensor(
372373
spec: TensorSpec,
373374
buffer_data: bytes,
374375
hashed: str,
375-
allocation_info: Optional[AllocationDetails],
376-
constant_tag: str,
376+
allocation_info: Optional[AllocationDetails] = None,
377+
constant_tag: Optional[str] = None,
377378
) -> int:
378379
"""Saves a new constant tensor to the constant buffer and returns the buffer idx"""
379380

@@ -395,7 +396,7 @@ def _save_new_const_tensor(
395396
if (
396397
spec.extra_tensor_info is not None
397398
and spec.extra_tensor_info.fully_qualified_name is not None
398-
and spec.extra_tensor_info.location == DataLocation.EXTERNAL
399+
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
399400
):
400401
assert (
401402
constant_tag is not None
@@ -460,7 +461,7 @@ def _tensor_spec_to_evalue(
460461
)
461462
elif (
462463
spec.extra_tensor_info is not None
463-
and spec.extra_tensor_info.location == DataLocation.EXTERNAL
464+
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
464465
):
465466
buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
466467
else:
@@ -1614,7 +1615,7 @@ def placeholder(
16141615
), "constant tagged tensors require a fully qualified name"
16151616
if spec.extra_tensor_info is None:
16161617
spec.extra_tensor_info = ExtraTensorInfo(
1617-
fully_qualified_name=fqn, location=DataLocation.EXTERNAL
1618+
fully_qualified_name=fqn, location=TensorDataLocation.EXTERNAL
16181619
)
16191620
else:
16201621
spec.extra_tensor_info.fully_qualified_name = fqn

0 commit comments

Comments
 (0)