Commit 215e3e4

Solve DPO pin-memory problem by hacking HybridParallelOptimizer (#2428)
1 parent 9f42fc3 commit 215e3e4

File tree

3 files changed: +34 -1 lines changed

paddleformers/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
@@ -158,6 +158,7 @@
     ShardingOption,
     TrainerMemoryTracker,
     TrainOutput,
+    _insert_sync,
     download_recovery_ckpt_from_pdc,
     find_batch_size,
     get_last_checkpoint,
@@ -2413,6 +2414,9 @@ def get_expected_keys(inputs, keys):
         ):
             self.optimizer._set_broadcast_overlap(True, model)

+        # To solve the DPO pin-memory problem, temporarily override the _insert_sync method.
+        self.optimizer._insert_sync = types.MethodType(_insert_sync, self.optimizer)
+
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
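
For context, the hunk above swaps in the replacement `_insert_sync` at runtime by rebinding it as a method on the existing optimizer instance via `types.MethodType`. A minimal sketch of that rebinding mechanism, independent of Paddle; the `Optimizer` class and `_patched_sync` function here are made-up stand-ins, not the real HybridParallelOptimizer:

import types

class Optimizer:
    def _insert_sync(self, sync_var):
        return "original implementation"

def _patched_sync(self, sync_var):
    # `self` is the optimizer instance, exactly as in a normally defined method.
    return f"patched implementation handling {sync_var!r}"

opt = Optimizer()
# Rebind the replacement function as a bound method on this single instance;
# other Optimizer instances keep the original behavior.
opt._insert_sync = types.MethodType(_patched_sync, opt)
print(opt._insert_sync("master_weight"))  # patched implementation handling 'master_weight'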

paddleformers/trainer/trainer_utils.py

Lines changed: 29 additions & 0 deletions
@@ -48,6 +48,7 @@
 from ..utils.import_utils import is_paddle_cuda_available, is_psutil_available
 from ..utils.log import logger
 from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
+from ..utils.tools import get_env_device
 from .utils.helper import distributed_file

 __all__ = [
@@ -1252,3 +1253,31 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
         raise RuntimeError(
             f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download checkpoint from PDC, recovery_checkpoint_path: {recovery_checkpoint_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
         )
+
+
+def _insert_sync(self, sync_var, src, mp_group, sync_mode):
+    # Get the device type where sync_var currently lives.
+    original_device = "pin_memory" if str(sync_var.place) == "Place(gpu_pinned)" else "Other"
+
+    # If sync_var is in pinned memory, first move it to CUDA or another device.
+    if original_device == "pin_memory":
+        if get_env_device() == "gpu":
+            sync_var = sync_var.cuda()
+        else:
+            sync_var = sync_var.to(get_env_device())
+
+    if sync_mode == "broadcast":
+        paddle.distributed.broadcast(sync_var, src=src, group=mp_group, sync_op=True)
+    else:
+        paddle.distributed.all_reduce(sync_var, group=mp_group, sync_op=True)
+        sync_var.multiply_(
+            paddle.full(
+                shape=[],
+                dtype=sync_var.dtype,
+                fill_value=(1.0 / mp_group.nranks),
+            )
+        )
+
+    # Move it back to pinned memory.
+    if original_device == "pin_memory":
+        sync_var = paddle.to_tensor(sync_var, place=paddle.CUDAPinnedPlace())
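
The pinned-memory detour above boils down to: copy the tensor from pinned host memory to the compute device, run the collective there, then copy the result back. A minimal standalone sketch of those moves, assuming a CUDA-enabled Paddle build; the tensor name and the scaling used as a stand-in for the collective op are illustrative only:

import paddle

# Create a tensor in pinned (page-locked) host memory.
master_weight = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.CUDAPinnedPlace())
print(str(master_weight.place))  # Place(gpu_pinned)

# Collectives expect tensors on the compute device, so move to the GPU first.
work_copy = master_weight.cuda()

# Stand-in for broadcast/all_reduce plus the 1/nranks scaling done above.
work_copy = work_copy * 0.5

# Copy the result back into pinned host memory, mirroring the final step of _insert_sync.
master_weight = paddle.to_tensor(work_copy, place=paddle.CUDAPinnedPlace())
print(str(master_weight.place))  # Place(gpu_pinned)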

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 blobfile
 colorlog
-seqeval
+scikit-learn
 multiprocess<=0.70.12.2
 datasets >= 2.0.0
 tqdm
