Skip to content

Commit eae8d9f

Browse files
support offload/reload optimizer's states for custom device (#9467)
1 parent 0b6284e commit eae8d9f

File tree

1 file changed

+13 -7 lines changed

1 file changed

+13 -7 lines changed

paddlenlp/trainer/trainer.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@
112112
)
113113
from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available
114114
from ..utils.log import logger
115+
from ..utils.tools import get_env_device
115116
from .argparser import strtobool
116117
from .integrations import get_reporting_integration_callbacks
117118
from .plugins.timer import RuntimeTimer, get_timers, set_timers
@@ -1773,10 +1774,6 @@ def apply_decay_param_fun(x):
17731774
return self.optimizer
17741775

17751776
def _apply_to_optimizer(self, action):
1776-
if "gpu" not in paddle.device.get_device():
1777-
logger.warning("offload/reload optimizer's states is only supported on GPU devices.")
1778-
return
1779-
17801777
attributes = [
17811778
("_accumulators", "_moment1_acc_str"),
17821779
("_accumulators", "_moment2_acc_str"),
@@ -1791,13 +1788,22 @@ def _apply_to_optimizer(self, action):
17911788
target_attr = target_attr[getattr(self.optimizer, attr[1])]
17921789

17931790
for key, value in target_attr.items():
1794-
target_attr[key] = getattr(value, action)()
1791+
if get_env_device() == "gpu":
1792+
target_attr[key] = getattr(value, action)()
1793+
else:
1794+
target_attr[key] = getattr(value, "to")(action)
17951795

17961796
def _offload_optimizer(self):
1797-
self._apply_to_optimizer("pin_memory")
1797+
if get_env_device() == "gpu":
1798+
self._apply_to_optimizer("pin_memory")
1799+
else:
1800+
self._apply_to_optimizer("cpu")
17981801

17991802
def _reload_optimizer(self):
1800-
self._apply_to_optimizer("cuda")
1803+
if get_env_device() == "gpu":
1804+
self._apply_to_optimizer("cuda")
1805+
else:
1806+
self._apply_to_optimizer(get_env_device())
18011807

18021808
def _load_rng_state(self, checkpoint):
18031809
# Load RNG states from `checkpoint`

0 commit comments

Comments
 (0)