6262    DoubleList ,
6363    EValue ,
6464    ExecutionPlan ,
65+     ExtraTensorInfo ,
6566    FreeCall ,
6667    Instruction ,
6768    Int ,
@@ -120,6 +121,14 @@ class _ProgramState:
120121    # and should be copied to Program.backend_delegate_data. 
121122    backend_delegate_data : List [BackendDelegateInlineData ] =  field (default_factory = list )
122123
124+     # Constants are optionally stored in external files. 
125+     # Aggregate unique external constants into one buffer. 
126+     external_constant_buffer : List [bytes ] =  field (default_factory = list )
127+     external_constant_hash : Dict [str , int ] =  field (default_factory = dict )
128+     # Each constant_tag groups a set of constants together. 
129+     # {constant_tag: {fqn: index into external_constant_buffer}} 
130+     external_constant_map : Dict [str , Dict [str , int ]] =  field (default_factory = dict )
131+ 
123132
124133@dataclass  
125134class  _EmitterState :
@@ -328,7 +337,9 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
328337            ExportErrorType .NOT_SUPPORTED , f"Unknown list type: { val_type }  " 
329338        )
330339
331-     def  _tensor_spec_to_evalue (self , spec : TensorSpec ) ->  EValue :
340+     def  _tensor_spec_to_evalue (
341+         self , spec : TensorSpec , constant_tag : Optional [str ] =  None 
342+     ) ->  EValue :
332343        """Constructs an EValue from the given TensorSpec.""" 
333344
334345        allocation_info  =  None 
@@ -389,6 +400,8 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
389400                buffer_idx  =  self .program_state .cached_spec_mutable_hash_values .get (
390401                    hashed , - 1 
391402                )
403+             elif  spec .location  ==  DataLocation .EXTERNAL :
404+                 buffer_idx  =  self .program_state .external_constant_hash .get (hashed , - 1 )
392405            else :
393406                buffer_idx  =  self .program_state .cached_spec_hash_values .get (hashed , - 1 )
394407
@@ -405,6 +418,23 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
405418                        buffer_idx 
406419                    )
407420                    self .program_state .mutable_buffer .append (buffer )
421+ 
422+                 # Constant tensor, stored in external file. 
423+                 elif  spec .location  ==  DataLocation .EXTERNAL :
424+                     assert  (
425+                         spec .extra_tensor_info  is  not   None 
426+                         and  spec .extra_tensor_info .fully_qualified_name  is  not   None 
427+                     ), "Fully qualified name is not set for external tensor" 
428+                     buffer_idx  =  len (self .program_state .external_constant_buffer )
429+                     self .program_state .external_constant_hash [hashed ] =  buffer_idx 
430+                     self .program_state .external_constant_buffer .append (buffer_data )
431+                     if  constant_tag :
432+                         if  constant_tag  not  in   self .program_state .external_constant_map :
433+                             self .program_state .external_constant_map [constant_tag ] =  {}
434+                         self .program_state .external_constant_map [constant_tag ][
435+                             spec .extra_tensor_info .fully_qualified_name   # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. 
436+                         ] =  buffer_idx 
437+                 # Constant tensor, stored in PTE. 
408438                else :
409439                    buffer_idx  =  len (self .program_state .constant_buffer )
410440                    self .program_state .cached_spec_hash_values [hashed ] =  buffer_idx 
@@ -1539,11 +1569,24 @@ def placeholder(
15391569        https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder 
15401570        """ 
15411571        spec  =  self .node .meta ["spec" ]
1572+         constant_tag  =  self .node .meta .get ("constant_tag" , None )
15421573        is_user_input  =  True 
15431574
15441575        if  isinstance (target , str ) and  isinstance (spec , TensorSpec ):
15451576            fqn , is_mutable_buffer  =  self ._find_fqn_for_placeholder (target , spec )
15461577
1578+             # If the placeholder has a constant_tag, it is external to the PTE file 
1579+             # and requires a fqn and location=DataLocation.EXTERNAL 
1580+             if  constant_tag  is  not   None :
1581+                 assert  (
1582+                     fqn  is  not   None 
1583+                 ), "constant tagged tensors require a fully qualified name" 
1584+                 if  spec .extra_tensor_info  is  None :
1585+                     spec .extra_tensor_info  =  ExtraTensorInfo (fully_qualified_name = fqn )
1586+                 else :
1587+                     spec .extra_tensor_info .fully_qualified_name  =  fqn 
1588+                 spec .location  =  DataLocation .EXTERNAL 
1589+ 
15471590            # From the fqn find the corresponding tensor 
15481591            real_tensor  =  None 
15491592            if  fqn  in  self .exported_program .state_dict :
@@ -1581,7 +1624,7 @@ def placeholder(
15811624            spec .const  =  not  (is_user_input  or  is_mutable_buffer )
15821625
15831626        evalue  =  (
1584-             self ._tensor_spec_to_evalue (spec )
1627+             self ._tensor_spec_to_evalue (spec ,  constant_tag )
15851628            if  isinstance (spec , TensorSpec )
15861629            else  self ._constant_to_evalue (spec , None )
15871630        )
0 commit comments