26 | 26 | from paddle.io import (
27 | 27 |     DataLoader,
28 | 28 | )
29 |    | -
   | 29 | +import paddle.distributed as dist
   | 30 | +from paddle.distributed import fleet
   | 31 | +import functools
30 | 32 | from deepmd.common import (
31 | 33 |     symlink_prefix_files,
32 | 34 | )
@@ -101,6 +103,11 @@ def __init__(
101 | 103 |         Args:
102 | 104 |         - config: The Dict-like configuration with training options.
103 | 105 |         """
    | 106 | +        from paddle.distributed import fleet
    | 107 | +        mesh_dims = [("dp", 32)]
    | 108 | +        fleet.auto.create_mesh(mesh_dims)
    | 109 | +        fleet.init(is_collective=True)
    | 110 | +
104 | 111 |         enable_prim(True)
105 | 112 |         if init_model is not None:
106 | 113 |             resume_model = init_model
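
The hunk above hardcodes a 32-way "dp" mesh, which presumably has to match the number of ranks started by `paddle.distributed.launch`. A minimal sketch of the same setup that derives the mesh size from the launched world size instead of the literal 32 (an assumed alternative for illustration, not what this commit does):

    import paddle.distributed as dist
    from paddle.distributed import fleet

    # assumed alternative to the hardcoded [("dp", 32)]: size the data-parallel
    # axis from the number of launched ranks so the mesh matches the job size
    world_size = dist.get_world_size()
    fleet.auto.create_mesh([("dp", world_size)])
    fleet.init(is_collective=True)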
@@ -748,22 +755,39 @@ def step(_step_id, task_key="Default") -> None:
748 | 755 |                 if self.world_size > 1
749 | 756 |                 else contextlib.nullcontext
750 | 757 |             )
751 |     | -            with sync_context():
752 |     | -                with nvprof_context(enable_profiling, "Forward pass"):
753 |     | -                    model_pred, loss, more_loss = self.wrapper(
754 |     | -                        **input_dict,
755 |     | -                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
756 |     | -                        label=label_dict,
757 |     | -                        task_key=task_key,
758 |     | -                    )
759 |     | -
760 |     | -                with nvprof_context(enable_profiling, "Backward pass"):
761 |     | -                    loss.backward()
    | 758 | +
    | 759 | +            # with sync_context():
    | 760 | +            #     with nvprof_context(enable_profiling, "Forward pass"):
    | 761 | +            #         model_pred, loss, more_loss = self.wrapper(
    | 762 | +            #             **input_dict,
    | 763 | +            #             cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
    | 764 | +            #             label=label_dict,
    | 765 | +            #             task_key=task_key,
    | 766 | +            #         )
    | 767 | +
    | 768 | +            #     with nvprof_context(enable_profiling, "Backward pass"):
    | 769 | +            #         loss.backward()
    | 770 | +
    | 771 | +            # if self.world_size > 1:
    | 772 | +            #     # fuse + allreduce manually before optimization if use DDP + no_sync
    | 773 | +            #     # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
    | 774 | +            #     hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
    | 775 | +
    | 776 | +            with nvprof_context(enable_profiling, "Forward pass"):
    | 777 | +                for __key in ('coord', 'atype', 'box'):
    | 778 | +                    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
    | 779 | +                for __key, _ in label_dict.items():
    | 780 | +                    if isinstance(label_dict[__key], paddle.Tensor):
    | 781 | +                        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
    | 782 | +                model_pred, loss, more_loss = self.wrapper(
    | 783 | +                    **input_dict,
    | 784 | +                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
    | 785 | +                    label=label_dict,
    | 786 | +                    task_key=task_key,
    | 787 | +                )
762 | 788 |
763 |     | -            if self.world_size > 1:
764 |     | -                # fuse + allreduce manually before optimization if use DDP + no_sync
765 |     | -                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
766 |     | -                hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
    | 789 | +            with nvprof_context(enable_profiling, "Backward pass"):
    | 790 | +                loss.backward()
767 | 791 |
768 | 792 |             if self.gradient_max_norm > 0.0:
769 | 793 |                 with nvprof_context(enable_profiling, "Gradient clip"):
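
In the new forward pass, `dist.shard_tensor` with `dist.Shard(0)` marks the batch dimension of `coord`, `atype`, `box` and the tensor-valued labels as split across the "dp" mesh axis, so gradient synchronization is presumably handled by Paddle's auto-parallel engine and the manual `hpu.fused_allreduce_gradients` call from the commented-out DDP path is no longer needed. A small sketch of what `Shard(0)` means for one input tensor; the 2-rank mesh and shapes are assumptions for illustration, not values from this commit, and it is meant to run under `python -m paddle.distributed.launch` with two ranks:

    import paddle
    import paddle.distributed as dist

    # toy 1-D "dp" mesh over two ranks (assumption for illustration)
    mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

    # fake batch of coordinates, [batch=8, natoms*3=576]; shapes are illustrative only
    coord = paddle.randn([8, 576])

    # Shard(0) splits dim 0 across the "dp" axis: the global shape stays [8, 576],
    # but each of the two ranks only stores a [4, 576] slice of the data
    coord_dist = dist.shard_tensor(coord, mesh=mesh, placements=[dist.Shard(0)])
    print(coord_dist.shape)  # [8, 576] -- the logical (global) shape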