Merge dsv3 tainer part #2487
base: develop
Changes from 5 commits: f789ca3, d57a21b, 626ef5a, 11318b9, 6a72aed
@@ -32,6 +32,7 @@
from paddle.distributed import fleet

from ..utils.env import PREFIX_CHECKPOINT_DIR
from ..utils.fault_tolerance import is_ft_env
from ..utils.log import logger
from ..utils.pdc_sdk import FLASH_DEVICE
from .trainer_utils import (
@@ -1397,12 +1398,7 @@ def is_segment_parallel_supported():
            else:
                order = ["dp", "sharding", "pp", "mp"]
            if self.use_expert_parallel:
                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
                    order.insert(-1, "ep")
                    sd_idx = order.index("sharding")
                    # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"]
                    # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"]
                    order.insert(sd_idx, "moe_sharding")
                order = order[1:-1] + ["dp", "mp"]

            if is_segment_parallel_supported():
                hybrid_configs = {
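For context, here is a standalone sketch (not part of the diff) of what the removed branch did to the hybrid-parallel group order; only the trailing `order = order[1:-1] + ["dp", "mp"]` line survives this PR. It uses the sharding-first order shown in the hunk above.

# Sketch of the pre-PR behaviour of the deleted branch.
order = ["dp", "sharding", "pp", "mp"]

# Deleted logic: register an expert-parallel axis and a moe_sharding axis.
order.insert(-1, "ep")                                 # ["dp", "sharding", "pp", "ep", "mp"]
order.insert(order.index("sharding"), "moe_sharding")  # ["dp", "moe_sharding", "sharding", "pp", "ep", "mp"]

# Kept logic: rotate "dp" and "mp" to the tail for expert parallelism.
order = order[1:-1] + ["dp", "mp"]
print(order)  # ['moe_sharding', 'sharding', 'pp', 'ep', 'dp', 'mp']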
@@ -1545,6 +1541,12 @@ def is_segment_parallel_supported():
                    assert (
                        "split_param" in sharding_parallel_config
                    ), "split_param should be set when enable_stage1_allgather_overlap."
                    use_casual_mask = os.getenv("USE_CASUAL_MASK", "False")

                    assert use_casual_mask, "enable_stage1_allgather_overlap requires USE_CASUAL_MASK=True."
                    assert self.logging_steps > 1, (
                        "The logging_steps should be greater than 1 for enable_stage1_allgather_overlap, "
                        f"but got logging_steps={self.logging_steps}."
                    )

                if "split_param" in sharding_parallel_config:
                    if ShardingOption.SHARD_OP not in self.sharding:
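One observation on the new check (not a change made by this PR): `os.getenv` returns a string, so `assert use_casual_mask` passes for any non-empty value, including the default "False". A minimal, hypothetical sketch of an explicit boolean parse, for comparison only:

import os

# Hypothetical helper, not part of this PR: turn the USE_CASUAL_MASK string
# into a real boolean so the assertion distinguishes "True" from "False".
def env_flag(name: str, default: str = "False") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

use_casual_mask = env_flag("USE_CASUAL_MASK")
assert use_casual_mask, "enable_stage1_allgather_overlap requires USE_CASUAL_MASK=True."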
@@ -1556,6 +1558,9 @@ def is_segment_parallel_supported():
            fleet.init(is_collective=True, strategy=strategy)
            logger.info(strategy)

            if self.expert_parallel_degree > 1:
                self.add_moe_comm_group()

Review comment: Deleting this raises an error.

        elif self.enable_auto_parallel:
            self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
            self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
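The new guard only sets up the extra MoE communication group when expert parallelism is actually in use; as the inline review comment notes, removing the call raises an error. A rough, hypothetical sketch of what building expert-parallel groups generally looks like in Paddle (this is not the body of `add_moe_comm_group`, and the contiguous-rank layout is an assumption):

import paddle.distributed as dist

# Hypothetical illustration: carve the world into contiguous expert-parallel
# groups of size expert_parallel_degree.
def build_expert_parallel_groups(world_size: int, expert_parallel_degree: int):
    groups = []
    for start in range(0, world_size, expert_parallel_degree):
        ranks = list(range(start, start + expert_parallel_degree))
        groups.append(dist.new_group(ranks=ranks))
    return groups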
@@ -1902,32 +1907,12 @@ def is_segment_parallel_supported():
        self.refined_recompute = refined_recompute_dict

        # process fault tolerance settings
        pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP")
        if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0:
            self.resume_from_checkpoint = os.path.join(FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}")
            logger.warning(
                f"PDC_FC_INIT_STEP {pdc_zcc_init_step} has been specified, automatically resume from FLASH_DEVICE: {self.resume_from_checkpoint}"
            )
        if self.flash_device_save_steps > 0:
            assert (
                self.enable_zero_cost_checkpoint
            ), "flash_device_save_steps should only be set in zero cost checkpoint save mode with flash device mounted."

        if self.enable_zero_cost_checkpoint:
            assert (
                "enable_fuse_optimizer_states" in sharding_parallel_config
            ), "zero cost checkpoint must be used when enable_fuse_optimizer_states is enabled in sharding parallel config"

            assert (
                self.flash_device_save_steps % self.zcc_ema_interval == 0
            ), f"flash_device_save_steps[{self.flash_device_save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
            assert (
                self.save_steps % self.zcc_ema_interval == 0
            ), f"save_steps[{self.save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
            if self.zcc_save_ema_coef is not None:
                assert (
                    self.zcc_workers_num == 1
                ), "EMA function in zero cost checkpoint mode does not support zcc_workers_num > 1 for now."
            if not is_ft_env():
                if self.pdc_download_ckpt:
                    logger.warning(
                        "pdc_download_ckpt can only be set as true inside FT environment. Automatically disable it now."
                    )
                    self.pdc_download_ckpt = False

    def _post_init_parallel_degree(self):
        self.use_hybrid_parallel = False
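Regarding the review question about this removed fault-tolerance block: before this PR, a positive PDC_FC_INIT_STEP forced resumption from a checkpoint on the flash device. A standalone sketch of the removed path resolution, with placeholder values standing in for the imported constants:

import os

FLASH_DEVICE = "/path/to/flash_device"  # placeholder; the real value comes from the pdc_sdk utils
PREFIX_CHECKPOINT_DIR = "checkpoint"    # placeholder; the real value comes from the env utils

pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP")  # e.g. "1000"
if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0:
    resume_from_checkpoint = os.path.join(FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}")
    # With the placeholders above and PDC_FC_INIT_STEP=1000:
    # "/path/to/flash_device/checkpoint-1000"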
@@ -1994,6 +1979,11 @@ def _post_init_parallel_degree(self):
            logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
            self.sharding = []

        if sharding_parallel_degree > 1:
            assert (
                sharding_parallel_degree % expert_parallel_degree == 0
            ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."

        self.data_parallel_degree = world_size // (
            sharding_parallel_degree
            * tensor_parallel_degree
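The new assertion requires the sharding degree to be a multiple of the expert-parallel degree, presumably so expert-parallel groups tile evenly within each sharding group. A quick worked check with example degrees (values chosen for illustration, not taken from the PR):

# Example degrees, not from the PR.
def degrees_compatible(sharding_parallel_degree: int, expert_parallel_degree: int) -> bool:
    return sharding_parallel_degree % expert_parallel_degree == 0

print(degrees_compatible(8, 4))  # True  -> passes the new assertion
print(degrees_compatible(8, 2))  # True  -> passes
print(degrees_compatible(8, 3))  # False -> would trip the assertion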
Review comment: Why is this being deleted? Will removing it affect the original logic?
Reply: Without this change, an error is raised.