|
48 | 48 | from ..utils.import_utils import is_paddle_cuda_available, is_psutil_available
|
49 | 49 | from ..utils.log import logger
|
50 | 50 | from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
|
| 51 | +from ..utils.tools import get_env_device |
51 | 52 | from .utils.helper import distributed_file
|
52 | 53 |
|
53 | 54 | __all__ = [
|
@@ -1252,3 +1253,34 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
|
1252 | 1253 | raise RuntimeError(
|
1253 | 1254 | f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download checkpoint from PDC, recovery_checkpoint_path: {recovery_checkpoint_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
|
1254 | 1255 | )
|
| 1256 | + |
| 1257 | + |
| 1258 | +def _insert_sync(self, sync_var, src, mp_group, sync_mode): |
| 1259 | +    # Record whether sync_var currently lives in CUDA pinned memory |
| 1260 | +    original_device = "pin_memory" if str(sync_var.place) == "Place(gpu_pinned)" else "Other" |
| 1261 | + |
| 1262 | +    # If sync_var is in pinned memory, first move it to CUDA or another device |
| 1263 | +    if original_device == "pin_memory": |
| 1264 | +        if get_env_device() == "gpu": |
| 1265 | +            sync_var = sync_var.cuda() |
| 1266 | +        else: |
| 1267 | +            sync_var = sync_var.to(get_env_device()) |
| 1268 | + |
| 1269 | +    if sync_mode == "broadcast": |
| 1270 | +        paddle.distributed.broadcast(sync_var, src=src, group=mp_group, sync_op=True) |
| 1271 | +    else: |
| 1272 | +        paddle.distributed.all_reduce(sync_var, group=mp_group, sync_op=True) |
| 1273 | +        sync_var.multiply_( |
| 1274 | +            paddle.full( |
| 1275 | +                shape=[], |
| 1276 | +                dtype=sync_var.dtype, |
| 1277 | +                fill_value=(1.0 / mp_group.nranks), |
| 1278 | +            ) |
| 1279 | +        ) |
| 1280 | + |
| 1281 | +    # Move it back to pinned memory |
| 1282 | +    if original_device == "pin_memory": |
| 1283 | +        sync_var = paddle.to_tensor(sync_var, place=paddle.CUDAPinnedPlace()) |
| 1284 | + |
| 1285 | +    # Return the synced tensor; the pinned-memory path works on a copy, so callers need the result |
| 1286 | +    return sync_var |
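
For context, here is a hypothetical usage sketch (illustration only, not part of the diff) of how a caller might use `_insert_sync` to keep an offloaded tensor consistent across a model-parallel group. The group construction and the `buf` tensor below are assumptions made for the example, and `self` is passed as `None` because the helper body never touches it.

```python
# Hypothetical usage sketch for _insert_sync (names below are assumptions).
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
# Assume every rank in the job belongs to one model-parallel group.
mp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# An offloaded tensor kept in CUDA pinned memory between uses.
buf = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.CUDAPinnedPlace())

# Broadcast the copy held by the group's first rank to every other rank.
buf = _insert_sync(None, buf, src=mp_group.ranks[0], mp_group=mp_group, sync_mode="broadcast")

# Any other sync_mode all-reduces and averages the per-rank copies instead.
buf = _insert_sync(None, buf, src=0, mp_group=mp_group, sync_mode="mean")
```

The temporary move onto the device before the collective is presumably needed because the broadcast/all-reduce ops operate on device tensors rather than pinned host memory, which is why the helper copies back to `CUDAPinnedPlace` afterwards.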