 )
 from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available
 from ..utils.log import logger
+from ..utils.tools import get_env_device
 from .argparser import strtobool
 from .integrations import get_reporting_integration_callbacks
 from .plugins.timer import RuntimeTimer, get_timers, set_timers
@@ -1773,10 +1774,6 @@ def apply_decay_param_fun(x):
         return self.optimizer

     def _apply_to_optimizer(self, action):
-        if "gpu" not in paddle.device.get_device():
-            logger.warning("offload/reload optimizer's states is only supported on GPU devices.")
-            return
-
         attributes = [
             ("_accumulators", "_moment1_acc_str"),
             ("_accumulators", "_moment2_acc_str"),
@@ -1791,13 +1788,22 @@ def _apply_to_optimizer(self, action):
                 target_attr = target_attr[getattr(self.optimizer, attr[1])]

             for key, value in target_attr.items():
-                target_attr[key] = getattr(value, action)()
+                if get_env_device() == "gpu":
+                    target_attr[key] = getattr(value, action)()
+                else:
+                    target_attr[key] = getattr(value, "to")(action)

     def _offload_optimizer(self):
-        self._apply_to_optimizer("pin_memory")
+        if get_env_device() == "gpu":
+            self._apply_to_optimizer("pin_memory")
+        else:
+            self._apply_to_optimizer("cpu")

     def _reload_optimizer(self):
-        self._apply_to_optimizer("cuda")
+        if get_env_device() == "gpu":
+            self._apply_to_optimizer("cuda")
+        else:
+            self._apply_to_optimizer(get_env_device())

     def _load_rng_state(self, checkpoint):
         # Load RNG states from `checkpoint`