@@ -618,6 +618,19 @@ def __init__(
618618 "Refer to https://github.com/NVIDIA-NeMo/RL/issues/1164 for latest updates with this issue."
619619 )
620620
+        optimizer_cpu_offload = self.cfg["megatron_cfg"]["optimizer"][
+            "optimizer_cpu_offload"
+        ]
+        optimizer_offload_fraction = self.cfg["megatron_cfg"]["optimizer"][
+            "optimizer_offload_fraction"
+        ]
+        if optimizer_cpu_offload:
+            # A hybrid optimizer (state split between GPU and CPU) is currently unsupported because it conflicts
+            # with how NeMo-RL offloads/onloads the optimizer between generation and training, so a CPU optimizer requires offload_fraction == 1.0.
+            assert optimizer_offload_fraction == 1.0, (
+                "Currently for optimizer offloading, only optimizer_offload_fraction=1.0 is supported"
+            )
+
         checkpoint_config = CheckpointConfig(
             save_interval=100,
             save=weights_path,
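
For reference, a minimal sketch of configurations that pass and fail the new check; the megatron_cfg layout mirrors the keys read above, and the variable names are illustrative only:

# Passes: full CPU offload.
ok_cfg = {"megatron_cfg": {"optimizer": {
    "optimizer_cpu_offload": True,
    "optimizer_offload_fraction": 1.0,
}}}

# Trips the assert: hybrid (partial) offload is not supported.
bad_cfg = {"megatron_cfg": {"optimizer": {
    "optimizer_cpu_offload": True,
    "optimizer_offload_fraction": 0.5,
}}}

# Unaffected: the check only runs when optimizer_cpu_offload is enabled.
gpu_cfg = {"megatron_cfg": {"optimizer": {
    "optimizer_cpu_offload": False,
    "optimizer_offload_fraction": 0.0,
}}}
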
@@ -1759,7 +1772,11 @@ def prepare_for_training(self, *args, **kwargs):
         self.model.train()
 
         # Move optimizer state to CUDA if it exists
-        if hasattr(self, "optimizer") and self.optimizer is not None:
+        if (
+            hasattr(self, "optimizer")
+            and self.optimizer is not None
+            and (not self.cfg["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"])
+        ):
             if isinstance(self.optimizer, ChainedOptimizer):
                 optimizer_state = self.optimizer.state
             else:
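
The guarded body (partially visible above) walks the optimizer's per-parameter state and moves its tensors to CUDA. Below is a self-contained sketch of that pattern, assuming a torch.optim-style state dict rather than the exact Megatron optimizer internals; move_optimizer_state is a hypothetical helper, not NeMo-RL's implementation:

import torch

def move_optimizer_state(optimizer_state: dict, device: str) -> None:
    # Move every tensor held in the per-parameter state (e.g. exp_avg and
    # exp_avg_sq for Adam) to `device`; non-tensor entries such as step
    # counters are left untouched.
    for param_state in optimizer_state.values():
        for key, value in param_state.items():
            if torch.is_tensor(value):
                param_state[key] = value.to(device, non_blocking=True)

With optimizer_cpu_offload enabled, the optimizer is expected to keep its own state resident on the host, so the new guard prevents this onload from duplicating that state on the GPU.
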
@@ -1786,7 +1803,11 @@ def offload_before_refit(self):
             self.model, "cpu", move_params=False, move_grads=True
         )  # get rid of grad buffers
         torch.randn(1).cuda()  # wake up torch allocator
-        if hasattr(self, "optimizer") and self.optimizer is not None:
+        if (
+            hasattr(self, "optimizer")
+            and self.optimizer is not None
+            and (not self.cfg["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"])
+        ):
             # Iterate through the state dictionaries for each parameter group
             if isinstance(self.optimizer, ChainedOptimizer):
                 optimizer_state = self.optimizer.state
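
Taken together, the two guards mean that with optimizer_cpu_offload=True the optimizer state never migrates between devices: prepare_for_training skips the CUDA onload and offload_before_refit skips the CPU offload, on the assumption that a fully CPU-offloaded optimizer (offload_fraction == 1.0, enforced by the new assert) already keeps its state resident on the host. Only model parameters and gradients continue to move between generation and training.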