@@ -588,20 +588,50 @@ def reset(key):
 
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
-        state = None
 
+        # This implementation is working around a few issues:
+        #
+        # (1) PyTorch's "extra state" infrastructure might be able to
+        #     support any picklable type, but they make no guarantees.
+        #     We have experienced problems (e.g. in ONNX export) with
+        #     non-tensor extra state.
+        # (2) PyTorch's checkpointing infrastructure does not remap
+        #     devices for "extra state" like it does for "state dict".
+        #     Thus, we want to avoid putting extra state on the GPU
+        #     since it may be loaded on the wrong device.
+        # (3) The extra state consists of many small tensors. If we
+        #     want to copy them all to CPU, then we need to avoid the
+        #     overhead of many GPU-CPU memory transfers.
+        #
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/363
+
+        def to_cpu(src: torch.Tensor) -> torch.Tensor:
+            """Helper function to make CPU copy of tensor
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst = torch.empty_like(src, device="cpu")
+            dst.copy_(src, non_blocking=True)
+            return dst
+
+        # Store FP8 state if needed
+        state = None
         fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-
         if fp8_checkpoint:
+
+            # Copy tensors to CPU and store
             state = {}
-            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
-            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
-            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
-            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
-            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
-            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-
-            # Store other picklable values.
+            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
+            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
+            state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
+            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
+            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
+            state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)
+
+            # Store other picklable values
             extra = {}
             for k, v in self.fp8_meta.items():
                 if k != "buffer_index_and_autocast_key" and isinstance(
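
Note on item (3) above: each `to_cpu` call only queues an asynchronous device-to-host copy, so all of the small FP8 tensors can be staged with a single `torch.cuda.synchronize()` before serialization, rather than paying one implicit sync per tensor. A minimal standalone sketch of the same pattern (the tensor count and shapes are illustrative, and a CUDA device is assumed):

```python
import torch

def to_cpu(src: torch.Tensor) -> torch.Tensor:
    """Queue an async device-to-host copy; sync before reading the result."""
    dst = torch.empty_like(src, device="cpu")
    dst.copy_(src, non_blocking=True)
    return dst

# Illustrative stand-ins for the many small FP8 scaling tensors.
gpu_tensors = [torch.randn(16, device="cuda") for _ in range(8)]

# Queue all copies first, then synchronize once.
cpu_tensors = [to_cpu(t) for t in gpu_tensors]
torch.cuda.synchronize()  # all device-to-host copies are now complete
```
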
@@ -610,22 +640,23 @@ def get_extra_state(self) -> torch.Tensor:
                     extra[k] = v
             state["extra_fp8_variables"] = extra
 
-        if is_in_onnx_export_mode():
-            state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
-        else:
-            state_serialized = io.BytesIO()
-            torch.save(state, state_serialized)
-
+        # Serialize state into byte tensor
+        torch.cuda.synchronize()
+        state_serialized = bytearray(pickle.dumps(state))
+        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
         return state_serialized
 
     def set_extra_state(self, state: torch.Tensor) -> None:
         """Load previous state."""
         if state is None:
             return
 
+        # Load state
         if isinstance(state, torch.Tensor):
+            # Default format: byte tensor with pickled data
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
         elif isinstance(state, io.BytesIO):
+            # Deprecated format with io.BytesIO
             state.seek(0)
             state = torch.load(state, map_location="cuda")
         else:
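
For the new default format, the round trip is: pickle the state dict, wrap the result in a writable `bytearray` (`torch.frombuffer` requires a writable buffer), and view the bytes as a `uint8` tensor; loading reverses those steps. A minimal sketch with a hypothetical state dict standing in for the FP8 metadata:

```python
import pickle
import torch

# Hypothetical stand-in for the FP8 metadata dict.
state = {"scale_fwd": torch.ones(4), "extra_fp8_variables": {"num_gemms": 1}}

# Serialize: pickle, then view the bytes as a uint8 tensor.
# torch.frombuffer needs a writable buffer, hence the bytearray copy.
serialized = torch.frombuffer(bytearray(pickle.dumps(state)), dtype=torch.uint8)

# Deserialize: recover the raw bytes and unpickle.
restored = pickle.loads(serialized.detach().cpu().numpy().tobytes())
assert torch.equal(restored["scale_fwd"], state["scale_fwd"])
```
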
@@ -634,20 +665,32 @@ def set_extra_state(self, state: torch.Tensor) -> None:
         if state is None:
             return
 
-        # Load extra items.
+        # Load extra items
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
         if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
             del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
 
-        # Initialize before loading.
+        # Initialize before loading
         self.init_fp8_meta_tensors()
-        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
-        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
-        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
-        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+
+        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
+            """Helper function to copy tensor from CPU
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst.copy_(src, non_blocking=True)
+
+        # Load tensors
+        copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
+        copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
+        copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
+        copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
+        copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
+        copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
+        torch.cuda.synchronize()
 
     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
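
The load path mirrors the save path: `copy_tensor` only queues asynchronous host-to-device copies into the freshly initialized FP8 tensors, and a single `torch.cuda.synchronize()` at the end replaces a per-tensor sync. A standalone sketch of the pattern (shapes illustrative, CUDA device assumed):

```python
import torch

def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
    """Queue an async host-to-device copy; sync before relying on dst."""
    dst.copy_(src, non_blocking=True)

cpu_state = [torch.randn(16) for _ in range(8)]                # loaded checkpoint values
gpu_dests = [torch.empty(16, device="cuda") for _ in range(8)]  # freshly initialized tensors

# Queue all copies, then synchronize once at the end.
for src, dst in zip(cpu_state, gpu_dests):
    copy_tensor(src, dst)
torch.cuda.synchronize()  # all host-to-device copies are now complete
```
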