112
112
)
113
113
from ..utils .import_utils import is_datasets_available , is_paddle_cuda_available
114
114
from ..utils .log import logger
115
+ from ..utils .tools import get_env_device
115
116
from .argparser import strtobool
116
117
from .integrations import get_reporting_integration_callbacks
117
118
from .plugins .timer import RuntimeTimer , get_timers , set_timers
@@ -1773,10 +1774,6 @@ def apply_decay_param_fun(x):
1773
1774
return self .optimizer
1774
1775
1775
1776
def _apply_to_optimizer (self , action ):
1776
- if "gpu" not in paddle .device .get_device ():
1777
- logger .warning ("offload/reload optimizer's states is only supported on GPU devices." )
1778
- return
1779
-
1780
1777
attributes = [
1781
1778
("_accumulators" , "_moment1_acc_str" ),
1782
1779
("_accumulators" , "_moment2_acc_str" ),
@@ -1791,13 +1788,22 @@ def _apply_to_optimizer(self, action):
1791
1788
target_attr = target_attr [getattr (self .optimizer , attr [1 ])]
1792
1789
1793
1790
for key , value in target_attr .items ():
1794
- target_attr [key ] = getattr (value , action )()
1791
+ if get_env_device () == "gpu" :
1792
+ target_attr [key ] = getattr (value , action )()
1793
+ else :
1794
+ target_attr [key ] = getattr (value , "to" )(action )
1795
1795
1796
1796
def _offload_optimizer(self):
    """Move the optimizer's state tensors off the accelerator.

    The action name is forwarded to ``self._apply_to_optimizer``:

    * when ``get_env_device()`` reports ``"gpu"``, the action is
      ``"pin_memory"``;
    * for any other reported device, the action is ``"cpu"``.

    Returns:
        None. The work is done by ``self._apply_to_optimizer``.
    """
    # Query the device once so the branch decision is made on a single value.
    device = get_env_device()
    action = "pin_memory" if device == "gpu" else "cpu"
    self._apply_to_optimizer(action)
1798
1801
1799
1802
def _reload_optimizer(self):
    """Move the optimizer's state tensors back onto the training device.

    The action name is forwarded to ``self._apply_to_optimizer``:

    * when ``get_env_device()`` reports ``"gpu"``, the action is ``"cuda"``;
    * for any other reported device, the reported device string itself is
      used as the action.

    Returns:
        None. The work is done by ``self._apply_to_optimizer``.
    """
    # Read the device a single time: the value that is compared against
    # "gpu" is then guaranteed to be the same value that gets forwarded.
    device = get_env_device()
    action = "cuda" if device == "gpu" else device
    self._apply_to_optimizer(action)
1801
1807
1802
1808
def _load_rng_state (self , checkpoint ):
1803
1809
# Load RNG states from `checkpoint`
0 commit comments