diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index abb7aa74b93..9267af4f2dc 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -97,3 +97,8 @@ class ExecutorchBackendConfig:
     # If set to true, all trainable weights will be stored in a separate file,
     # external to the PTE file.
     external_mutable_weights: bool = False
+
+    # If set to true, all mutable buffers will have their fully qualified names
+    # serialized in the PTE file. This flag is ignored when mutable buffers are
+    # not memory planned, since their names must be serialized in that case.
+    emit_mutable_buffer_names: bool = False
diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py
index f9571143a1b..f456626feed 100644
--- a/exir/emit/_emit_program.py
+++ b/exir/emit/_emit_program.py
@@ -118,6 +118,7 @@ def emit_program(
     methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
     emit_stacktrace: bool = False,
     prim_getters: Optional[Dict[str, Any]] = None,
+    emit_mutable_buffer_names: bool = False,
 ) -> EmitterOutput:
     """
     Given a exported program, it returns the program in the format
@@ -163,6 +164,7 @@ def emit_program(
         operator_cache={},
         delegate_cache={},
         emit_stacktrace=emit_stacktrace,
+        emit_mutable_buffer_names=emit_mutable_buffer_names,
     )
 
     gm = _remove_non_user_outputs(exported_program)
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 0cbc63bde21..9cc8a1e809c 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -149,6 +149,7 @@ class _EmitterState:
     # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
     delegate_cache: Dict[str, int]
     emit_stacktrace: bool
+    emit_mutable_buffer_names: bool
 
     spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)
 
@@ -1610,7 +1611,7 @@ def _find_fqn_for_placeholder(
         )
         return fqn, is_mutable_buffer
 
-    def placeholder(
+    def placeholder(  # noqa: C901
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Emits the value within the placeholder node.
@@ -1639,6 +1640,13 @@ def placeholder(
             else:
                 spec.extra_tensor_info.fully_qualified_name = fqn
                 spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
+        if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
+            if spec.extra_tensor_info is None:
+                spec.extra_tensor_info = ExtraTensorInfo(
+                    fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+                )
+            else:
+                spec.extra_tensor_info.fully_qualified_name = fqn
 
         # From the fqn find the corresponding tensor
         real_tensor = None
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 362796146ee..8d68eb8af7a 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1819,3 +1819,27 @@ def forward(self, input, label):
         ]
         self.assertEqual(external_map["net.linear.weight"], 0)
         self.assertEqual(external_map["net.linear.bias"], 1)
+
+    def test_emit_mutable_buffer_names(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+                self.register_buffer("buffer", torch.zeros(1, 2))
+
+            def forward(self, x):
+                self.buffer.add_(1)
+                return self.linear(x) + self.buffer
+
+        net = Net()
+
+        ep = export(net, (torch.randn(1, 2),), strict=True)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        ep = ep.to_executorch(
+            config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
+        )
+        for val in ep.executorch_program.execution_plan[0].values:
+            if isinstance(val, Tensor) and val.extra_tensor_info:
+                self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 7a2120f9e9b..ef857ffd011 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -1612,6 +1612,7 @@ def __init__(
             self._execution_programs,
             backend_config.emit_stacktrace,
             self._config_methods,
+            backend_config.emit_mutable_buffer_names,
         )
 
         # Serialize emitter output, ready to be written to a file.
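
For reviewers, a minimal sketch of how the new flag is meant to be consumed end to end, mirroring the test above. The Counter module and the save/restore motivation are hypothetical illustrations, not part of this change; the imports and calls follow what the test file already uses:

import torch
import torch.nn as nn
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.schema import Tensor
from torch.export import export


class Counter(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        self.count.add_(1)
        return x + self.count


ep = to_edge(export(Counter(), (torch.randn(1),), strict=True)).to_executorch(
    config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
)

# With the flag on, memory-planned mutable buffers carry their fully qualified
# names in the serialized program, so they can be located by name rather than
# by value index (e.g. to map saved training state back onto planned buffers).
fqns = [
    val.extra_tensor_info.fully_qualified_name
    for val in ep.executorch_program.execution_plan[0].values
    if isinstance(val, Tensor) and val.extra_tensor_info is not None
]
print(fqns)  # expected to include "count"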