6262 DoubleList ,
6363 EValue ,
6464 ExecutionPlan ,
65+ ExtraTensorInfo ,
6566 FreeCall ,
6667 Instruction ,
6768 Int ,
@@ -120,6 +121,14 @@ class _ProgramState:
120121 # and should be copied to Program.backend_delegate_data.
121122 backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
122123
124+ # Constants are optionally stored in external files.
125+ # Aggregate unique external constants into one buffer.
126+ external_constant_buffer : List [bytes ] = field (default_factory = list )
127+ external_constant_hash : Dict [str , int ] = field (default_factory = dict )
128+ # Each constant_tag groups a set of constants together.
129+ # {constant_tag: {fqn: index into external_constant_buffer}}
130+ external_constant_map : Dict [str , Dict [str , int ]] = field (default_factory = dict )
131+
123132
124133@dataclass
125134class _EmitterState :
@@ -328,7 +337,9 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
328337 ExportErrorType .NOT_SUPPORTED , f"Unknown list type: { val_type } "
329338 )
330339
331- def _tensor_spec_to_evalue (self , spec : TensorSpec ) -> EValue :
340+ def _tensor_spec_to_evalue (
341+ self , spec : TensorSpec , constant_tag : Optional [str ] = None
342+ ) -> EValue :
332343 """Constructs an EValue from the given TensorSpec."""
333344
334345 allocation_info = None
@@ -389,6 +400,8 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
389400 buffer_idx = self .program_state .cached_spec_mutable_hash_values .get (
390401 hashed , - 1
391402 )
403+ elif spec .location == DataLocation .EXTERNAL :
404+ buffer_idx = self .program_state .external_constant_hash .get (hashed , - 1 )
392405 else :
393406 buffer_idx = self .program_state .cached_spec_hash_values .get (hashed , - 1 )
394407
@@ -405,6 +418,23 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
405418 buffer_idx
406419 )
407420 self .program_state .mutable_buffer .append (buffer )
421+
422+ # Constant tensor, stored in external file.
423+ elif spec .location == DataLocation .EXTERNAL :
424+ assert (
425+ spec .extra_tensor_info is not None
426+ and spec .extra_tensor_info .fully_qualified_name is not None
427+ ), "Fully qualified name is not set for external tensor"
428+ buffer_idx = len (self .program_state .external_constant_buffer )
429+ self .program_state .external_constant_hash [hashed ] = buffer_idx
430+ self .program_state .external_constant_buffer .append (buffer_data )
431+ if constant_tag :
432+ if constant_tag not in self .program_state .external_constant_map :
433+ self .program_state .external_constant_map [constant_tag ] = {}
434+ self .program_state .external_constant_map [constant_tag ][
435+ spec .extra_tensor_info .fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
436+ ] = buffer_idx
437+ # Constant tensor, stored in PTE.
408438 else :
409439 buffer_idx = len (self .program_state .constant_buffer )
410440 self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
@@ -1539,11 +1569,24 @@ def placeholder(
15391569 https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
15401570 """
15411571 spec = self .node .meta ["spec" ]
1572+ constant_tag = self .node .meta .get ("constant_tag" , None )
15421573 is_user_input = True
15431574
15441575 if isinstance (target , str ) and isinstance (spec , TensorSpec ):
15451576 fqn , is_mutable_buffer = self ._find_fqn_for_placeholder (target , spec )
15461577
1578+ # If the placeholder has a constant_tag, it is external to the PTE file
1579+ # and requires a fqn and location=DataLocation.EXTERNAL
1580+ if constant_tag is not None :
1581+ assert (
1582+ fqn is not None
1583+ ), "constant tagged tensors require a fully qualified name"
1584+ if spec .extra_tensor_info is None :
1585+ spec .extra_tensor_info = ExtraTensorInfo (fully_qualified_name = fqn )
1586+ else :
1587+ spec .extra_tensor_info .fully_qualified_name = fqn
1588+ spec .location = DataLocation .EXTERNAL
1589+
15471590 # From the fqn find the corresponding tensor
15481591 real_tensor = None
15491592 if fqn in self .exported_program .state_dict :
@@ -1581,7 +1624,7 @@ def placeholder(
15811624 spec .const = not (is_user_input or is_mutable_buffer )
15821625
15831626 evalue = (
1584- self ._tensor_spec_to_evalue (spec )
1627+ self ._tensor_spec_to_evalue (spec , constant_tag )
15851628 if isinstance (spec , TensorSpec )
15861629 else self ._constant_to_evalue (spec , None )
15871630 )
0 commit comments