|
15 | 15 | from swift import get_logger |
16 | 16 | from swift.hub import HubApi, ModelScopeConfig |
17 | 17 | from swift.utils import (add_version_to_work_dir, broadcast_string, |
18 | | - get_dist_setting, is_dist, is_master, is_mp) |
| 18 | + get_dist_setting, get_pai_tensorboard_dir, is_dist, |
| 19 | + is_master, is_mp, is_pai_training_job) |
19 | 20 | from .dataset import (DATASET_MAPPING, get_custom_dataset, get_dataset, |
20 | 21 | register_dataset) |
21 | 22 | from .model import (MODEL_MAPPING, dtype_mapping, |
@@ -52,7 +53,7 @@ class SftArguments: |
52 | 53 | f"template_type choices: {list(TEMPLATE_MAPPING.keys()) + ['AUTO']}" |
53 | 54 | }) |
54 | 55 | output_dir: str = 'output' |
55 | | - add_output_dir_suffix: bool = True |
| 56 | + add_output_dir_suffix: Optional[bool] = None |
56 | 57 | ddp_backend: Literal['nccl', 'gloo', 'mpi', 'ccl'] = 'nccl' |
57 | 58 |
|
58 | 59 | seed: int = 42 |
@@ -214,6 +215,8 @@ def prepare_target_modules(self, target_modules): |
214 | 215 |
|
215 | 216 | def __post_init__(self) -> None: |
216 | 217 | handle_compatibility(self) |
| 218 | + if is_pai_training_job(): |
| 219 | + handle_pai_compat(self) |
217 | 220 | ds_config_folder = os.path.join(__file__, '..', '..', 'ds_config') |
218 | 221 | if self.deepspeed_config_path == 'default-zero2': |
219 | 222 | self.deepspeed_config_path = os.path.abspath( |
@@ -270,6 +273,8 @@ def __post_init__(self) -> None: |
270 | 273 | if not dist.is_initialized(): |
271 | 274 | dist.init_process_group(backend=self.ddp_backend) |
272 | 275 |
|
| 276 | + if self.add_output_dir_suffix is None: |
| 277 | + self.add_output_dir_suffix = True |
273 | 278 | if self.add_output_dir_suffix: |
274 | 279 | self.output_dir = os.path.join(self.output_dir, self.model_type) |
275 | 280 | self.output_dir = add_version_to_work_dir(self.output_dir) |
@@ -907,3 +912,17 @@ def handle_dataset_mixture(args: SftArguments, train_dataset, |
907 | 912 | return concatenate_datasets([train_dataset, mixed_dataset]) |
908 | 913 | else: |
909 | 914 | return train_dataset |
| 915 | + |
| 916 | + |
def handle_pai_compat(args: SftArguments) -> None:
    """Fill in still-unset arguments when running as a PAI training job.

    Must only be called when ``is_pai_training_job()`` returns ``True``;
    this is asserted on entry. Only fields whose value is still ``None``
    are modified, values that are already set are left as they are.

    Args:
        args: The arguments object, updated in place. ``logging_dir`` is
            set to the value returned by ``get_pai_tensorboard_dir()``
            when that value is not ``None``, and ``add_output_dir_suffix``
            is set to ``False``.
    """
    assert is_pai_training_job() is True
    logger.info('Handle pai compat...')

    tensorboard_dir = get_pai_tensorboard_dir()
    logging_dir_unset = args.logging_dir is None
    if logging_dir_unset and tensorboard_dir is not None:
        args.logging_dir = tensorboard_dir
        logger.info(f'Setting args.logging_dir: {args.logging_dir}')

    # None means the suffix option has not been given a value yet, so it
    # is switched off here.
    suffix_unset = args.add_output_dir_suffix is None
    if suffix_unset:
        args.add_output_dir_suffix = False
        logger.info(
            f'Setting args.add_output_dir_suffix: {args.add_output_dir_suffix}'
        )
0 commit comments