6363    DoubleList ,
6464    EValue ,
6565    ExecutionPlan ,
66+     ExtraTensorInfo ,
6667    FreeCall ,
6768    Instruction ,
6869    Int ,
7677    ScalarType ,
7778    String ,
7879    Tensor ,
80+     TensorDataLocation ,
7981    TensorList ,
8082    TensorShapeDynamism ,
8183)
@@ -121,6 +123,14 @@ class _ProgramState:
121123    # and should be copied to Program.backend_delegate_data. 
122124    backend_delegate_data : List [BackendDelegateInlineData ] =  field (default_factory = list )
123125
126+     # Constants are optionally stored in external files. 
127+     # Aggregate unique external constants into one buffer. 
128+     external_constant_buffer : List [bytes ] =  field (default_factory = list )
129+     external_constant_hash : Dict [str , int ] =  field (default_factory = dict )
130+     # Each constant_tag groups a set of constants together. 
131+     # {constant_tag: {fqn: index into external_constant_buffer}} 
132+     external_constant_map : Dict [str , Dict [str , int ]] =  field (default_factory = dict )
133+ 
124134
125135@dataclass  
126136class  _EmitterState :
@@ -363,7 +373,8 @@ def _save_new_const_tensor(
363373        spec : TensorSpec ,
364374        buffer_data : bytes ,
365375        hashed : str ,
366-         allocation_info : Optional [AllocationDetails ],
376+         allocation_info : Optional [AllocationDetails ] =  None ,
377+         constant_tag : Optional [str ] =  None ,
367378    ) ->  int :
368379        """Saves a new constant tensor to the constant buffer and returns the buffer idx""" 
369380
@@ -372,17 +383,45 @@ def _save_new_const_tensor(
372383
373384        # Update buffer_idx to point to the end of the list where we are adding the new buffer. 
374385        buffer  =  Buffer (storage = buffer_data )
386+ 
387+         # Tensor is mutable with initial state. 
375388        if  allocation_info :
376389            buffer_idx  =  len (self .program_state .mutable_buffer )
377390            self .program_state .cached_spec_mutable_hash_values [hashed ] =  buffer_idx 
378391            self .program_state .mutable_buffer .append (buffer )
392+ 
393+         # Tensor is constant. 
379394        else :
380-             buffer_idx  =  len (self .program_state .constant_buffer )
381-             self .program_state .cached_spec_hash_values [hashed ] =  buffer_idx 
382-             self .program_state .constant_buffer .append (buffer )
395+             # Tensor is stored outside of the PTE file. 
396+             if  (
397+                 spec .extra_tensor_info  is  not None 
398+                 and  spec .extra_tensor_info .fully_qualified_name  is  not None 
399+                 and  spec .extra_tensor_info .location  ==  TensorDataLocation .EXTERNAL 
400+             ):
401+                 assert  (
402+                     constant_tag  is  not None 
403+                 ), "Constant tag is not set for external tensor" 
404+ 
405+                 buffer_idx  =  len (self .program_state .external_constant_buffer )
406+                 self .program_state .external_constant_hash [hashed ] =  buffer_idx 
407+                 self .program_state .external_constant_buffer .append (buffer_data )
408+                 if  constant_tag  not  in self .program_state .external_constant_map :
409+                     self .program_state .external_constant_map [constant_tag ] =  {}
410+                 self .program_state .external_constant_map [constant_tag ][
411+                     spec .extra_tensor_info .fully_qualified_name   # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. 
412+                 ] =  buffer_idx 
413+ 
414+             # Tensor is stored in the PTE file. 
415+             else :
416+                 buffer_idx  =  len (self .program_state .constant_buffer )
417+                 self .program_state .cached_spec_hash_values [hashed ] =  buffer_idx 
418+                 self .program_state .constant_buffer .append (buffer )
419+ 
383420        return  buffer_idx 
384421
385-     def  _tensor_spec_to_evalue (self , spec : TensorSpec ) ->  EValue :
422+     def  _tensor_spec_to_evalue (
423+         self , spec : TensorSpec , constant_tag : Optional [str ] =  None 
424+     ) ->  EValue :
386425        """Constructs an EValue from the given TensorSpec.""" 
387426
388427        allocation_info  =  None 
@@ -420,13 +459,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
420459                buffer_idx  =  self .program_state .cached_spec_mutable_hash_values .get (
421460                    hashed , - 1 
422461                )
462+             elif  (
463+                 spec .extra_tensor_info  is  not None 
464+                 and  spec .extra_tensor_info .location  ==  TensorDataLocation .EXTERNAL 
465+             ):
466+                 buffer_idx  =  self .program_state .external_constant_hash .get (hashed , - 1 )
423467            else :
424468                buffer_idx  =  self .program_state .cached_spec_hash_values .get (hashed , - 1 )
425469
426470            # Haven't seen this constant before. 
427471            if  buffer_idx  ==  - 1 :
428472                buffer_idx  =  self ._save_new_const_tensor (
429-                     spec , buffer_data , hashed , allocation_info 
473+                     spec , buffer_data , hashed , allocation_info ,  constant_tag 
430474                )
431475
432476            if  spec .const  and  spec .nbytes () !=  len (buffer_data ):
@@ -1557,11 +1601,26 @@ def placeholder(
15571601        https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder 
15581602        """ 
15591603        spec  =  self .node .meta ["spec" ]
1604+         constant_tag  =  self .node .meta .get ("constant_tag" , None )
15601605        is_user_input  =  True 
15611606
15621607        if  isinstance (target , str ) and  isinstance (spec , TensorSpec ):
15631608            fqn , is_mutable_buffer  =  self ._find_fqn_for_placeholder (target , spec )
15641609
1610+             # If the placeholder has a constant_tag, it is external to the PTE file 
1611+             # and requires a fqn and location=TensorDataLocation.EXTERNAL 
1612+             if  constant_tag  is  not None :
1613+                 assert  (
1614+                     fqn  is  not None 
1615+                 ), "constant tagged tensors require a fully qualified name" 
1616+                 if  spec .extra_tensor_info  is  None :
1617+                     spec .extra_tensor_info  =  ExtraTensorInfo (
1618+                         fully_qualified_name = fqn , location = TensorDataLocation .EXTERNAL 
1619+                     )
1620+                 else :
1621+                     spec .extra_tensor_info .fully_qualified_name  =  fqn 
1622+                     spec .extra_tensor_info .location  =  TensorDataLocation .EXTERNAL 
1623+ 
15651624            # From the fqn find the corresponding tensor 
15661625            real_tensor  =  None 
15671626            if  fqn  in  self .exported_program .state_dict :
@@ -1599,7 +1658,7 @@ def placeholder(
15991658            spec .const  =  not  (is_user_input  or  is_mutable_buffer )
16001659
16011660        evalue  =  (
1602-             self ._tensor_spec_to_evalue (spec )
1661+             self ._tensor_spec_to_evalue (spec ,  constant_tag )
16031662            if  isinstance (spec , TensorSpec )
16041663            else  self ._constant_to_evalue (spec , None )
16051664        )
0 commit comments