6363 DoubleList ,
6464 EValue ,
6565 ExecutionPlan ,
66+ ExtraTensorInfo ,
6667 FreeCall ,
6768 Instruction ,
6869 Int ,
7677 ScalarType ,
7778 String ,
7879 Tensor ,
80+ TensorDataLocation ,
7981 TensorList ,
8082 TensorShapeDynamism ,
8183)
@@ -121,6 +123,14 @@ class _ProgramState:
121123 # and should be copied to Program.backend_delegate_data.
122124 backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
123125
126+ # Constants are optionally stored in external files.
127+ # Aggregate unique external constants into one buffer.
128+ external_constant_buffer : List [bytes ] = field (default_factory = list )
129+ external_constant_hash : Dict [str , int ] = field (default_factory = dict )
130+ # Each constant_tag groups a set of constants together.
131+ # {constant_tag: {fqn: index into external_constant_buffer}}
132+ external_constant_map : Dict [str , Dict [str , int ]] = field (default_factory = dict )
133+
124134
125135@dataclass
126136class _EmitterState :
@@ -363,7 +373,8 @@ def _save_new_const_tensor(
363373 spec : TensorSpec ,
364374 buffer_data : bytes ,
365375 hashed : str ,
366- allocation_info : Optional [AllocationDetails ],
376+ allocation_info : Optional [AllocationDetails ] = None ,
377+ constant_tag : Optional [str ] = None ,
367378 ) -> int :
368379 """Saves a new constant tensor to the constant buffer and returns the buffer idx"""
369380
@@ -372,17 +383,45 @@ def _save_new_const_tensor(
372383
373384 # Update buffer_idx to point to the end of the list where we are adding the new buffer.
374385 buffer = Buffer (storage = buffer_data )
386+
387+ # Tensor is mutable with initial state.
375388 if allocation_info :
376389 buffer_idx = len (self .program_state .mutable_buffer )
377390 self .program_state .cached_spec_mutable_hash_values [hashed ] = buffer_idx
378391 self .program_state .mutable_buffer .append (buffer )
392+
393+ # Tensor is constant.
379394 else :
380- buffer_idx = len (self .program_state .constant_buffer )
381- self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
382- self .program_state .constant_buffer .append (buffer )
395+ # Tensor is stored outside of the PTE file.
396+ if (
397+ spec .extra_tensor_info is not None
398+ and spec .extra_tensor_info .fully_qualified_name is not None
399+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
400+ ):
401+ assert (
402+ constant_tag is not None
403+ ), "Constant tag is not set for external tensor"
404+
405+ buffer_idx = len (self .program_state .external_constant_buffer )
406+ self .program_state .external_constant_hash [hashed ] = buffer_idx
407+ self .program_state .external_constant_buffer .append (buffer_data )
408+ if constant_tag not in self .program_state .external_constant_map :
409+ self .program_state .external_constant_map [constant_tag ] = {}
410+ self .program_state .external_constant_map [constant_tag ][
411+ spec .extra_tensor_info .fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
412+ ] = buffer_idx
413+
414+ # Tensor is stored in the PTE file.
415+ else :
416+ buffer_idx = len (self .program_state .constant_buffer )
417+ self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
418+ self .program_state .constant_buffer .append (buffer )
419+
383420 return buffer_idx
384421
385- def _tensor_spec_to_evalue (self , spec : TensorSpec ) -> EValue :
422+ def _tensor_spec_to_evalue (
423+ self , spec : TensorSpec , constant_tag : Optional [str ] = None
424+ ) -> EValue :
386425 """Constructs an EValue from the given TensorSpec."""
387426
388427 allocation_info = None
@@ -420,13 +459,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
420459 buffer_idx = self .program_state .cached_spec_mutable_hash_values .get (
421460 hashed , - 1
422461 )
462+ elif (
463+ spec .extra_tensor_info is not None
464+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
465+ ):
466+ buffer_idx = self .program_state .external_constant_hash .get (hashed , - 1 )
423467 else :
424468 buffer_idx = self .program_state .cached_spec_hash_values .get (hashed , - 1 )
425469
426470 # Haven't seen this constant before.
427471 if buffer_idx == - 1 :
428472 buffer_idx = self ._save_new_const_tensor (
429- spec , buffer_data , hashed , allocation_info
473+ spec , buffer_data , hashed , allocation_info , constant_tag
430474 )
431475
432476 if spec .const and spec .nbytes () != len (buffer_data ):
@@ -1557,11 +1601,26 @@ def placeholder(
15571601 https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
15581602 """
15591603 spec = self .node .meta ["spec" ]
1604+ constant_tag = self .node .meta .get ("constant_tag" , None )
15601605 is_user_input = True
15611606
15621607 if isinstance (target , str ) and isinstance (spec , TensorSpec ):
15631608 fqn , is_mutable_buffer = self ._find_fqn_for_placeholder (target , spec )
15641609
1610+ # If the placeholder has a constant_tag, it is external to the PTE file
1611+ # and requires a fqn and location=TensorDataLocation.EXTERNAL
1612+ if constant_tag is not None :
1613+ assert (
1614+ fqn is not None
1615+ ), "constant tagged tensors require a fully qualified name"
1616+ if spec .extra_tensor_info is None :
1617+ spec .extra_tensor_info = ExtraTensorInfo (
1618+ fully_qualified_name = fqn , location = TensorDataLocation .EXTERNAL
1619+ )
1620+ else :
1621+ spec .extra_tensor_info .fully_qualified_name = fqn
1622+ spec .extra_tensor_info .location = TensorDataLocation .EXTERNAL
1623+
15651624 # From the fqn find the corresponding tensor
15661625 real_tensor = None
15671626 if fqn in self .exported_program .state_dict :
@@ -1599,7 +1658,7 @@ def placeholder(
15991658 spec .const = not (is_user_input or is_mutable_buffer )
16001659
16011660 evalue = (
1602- self ._tensor_spec_to_evalue (spec )
1661+ self ._tensor_spec_to_evalue (spec , constant_tag )
16031662 if isinstance (spec , TensorSpec )
16041663 else self ._constant_to_evalue (spec , None )
16051664 )
0 commit comments