@@ -387,38 +387,36 @@ def _save_new_const_tensor(
387387 # Update buffer_idx to point to the end of the list where we are adding the new buffer.
388388 buffer = Buffer (storage = buffer_data )
389389
390- # Tensor is mutable with initial state.
391- if allocation_info :
390+ # Tensor is stored outside of the PTE file.
391+ if (
392+ spec .extra_tensor_info is not None
393+ and spec .extra_tensor_info .fully_qualified_name is not None
394+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
395+ ):
396+ assert (
397+ constant_tag is not None
398+ ), "Constant tag is not set for external tensor"
399+ # TODO (#7633): Handle case where we have both mutable and non mutable weights that we want to put in the same external file.
400+ # We will need to create 2 segments in that case, but it'll be a bit until we see this case. LLM finetuning will probably require this.
401+
402+ buffer_idx = len (self .program_state .external_constant_buffer )
403+ self .program_state .external_constant_hash [hashed ] = buffer_idx
404+ self .program_state .external_constant_buffer .append (buffer_data )
405+ if constant_tag not in self .program_state .external_constant_map :
406+ self .program_state .external_constant_map [constant_tag ] = {}
407+ self .program_state .external_constant_map [constant_tag ][
408+ spec .extra_tensor_info .fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
409+ ] = buffer_idx
410+ # Tensor is mutable with initial state. Place into mutable segment
411+ elif allocation_info :
392412 buffer_idx = len (self .program_state .mutable_buffer )
393413 self .program_state .cached_spec_mutable_hash_values [hashed ] = buffer_idx
394414 self .program_state .mutable_buffer .append (buffer )
395-
396- # Tensor is constant.
415+ # Tensor is stored in the PTE file.
397416 else :
398- # Tensor is stored outside of the PTE file.
399- if (
400- spec .extra_tensor_info is not None
401- and spec .extra_tensor_info .fully_qualified_name is not None
402- and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
403- ):
404- assert (
405- constant_tag is not None
406- ), "Constant tag is not set for external tensor"
407-
408- buffer_idx = len (self .program_state .external_constant_buffer )
409- self .program_state .external_constant_hash [hashed ] = buffer_idx
410- self .program_state .external_constant_buffer .append (buffer_data )
411- if constant_tag not in self .program_state .external_constant_map :
412- self .program_state .external_constant_map [constant_tag ] = {}
413- self .program_state .external_constant_map [constant_tag ][
414- spec .extra_tensor_info .fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
415- ] = buffer_idx
416-
417- # Tensor is stored in the PTE file.
418- else :
419- buffer_idx = len (self .program_state .constant_buffer )
420- self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
421- self .program_state .constant_buffer .append (buffer )
417+ buffer_idx = len (self .program_state .constant_buffer )
418+ self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
419+ self .program_state .constant_buffer .append (buffer )
422420
423421 return buffer_idx
424422
@@ -458,7 +456,7 @@ def _tensor_spec_to_evalue(
458456
459457 hashed = hashlib .sha256 (buffer_data ).hexdigest ()
460458
461- if allocation_info :
459+ if allocation_info and spec . extra_tensor_info is None :
462460 buffer_idx = self .program_state .cached_spec_mutable_hash_values .get (
463461 hashed , - 1
464462 )
0 commit comments