Skip to content

Commit 7a638ab

Browse files
committed
Merge commit '0248e8a1d38c0c73376ff8eca366869ed314cae4' into release/1.6
* commit '0248e8a1d38c0c73376ff8eca366869ed314cae4': fix doc (#376) Support PAI compat (#373)
2 parents c8a53f8 + 0248e8a commit 7a638ab

File tree

8 files changed

+75
-10
lines changed

8 files changed

+75
-10
lines changed

docs/source/LLM/命令行参数.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999

100100
### AdaLoRA微调参数
101101

102-
以下参数`sft_type`设置为`adalora`时生效。adalora的`target_modules`等参数继承于lora的对应参数。
102+
以下参数`sft_type`设置为`adalora`时生效。adalora的`target_modules`等参数继承于lora的对应参数,但`lora_dtype`参数不生效
103103

104104
- `--adalora_target_r`: 默认值8, adalora的平均rank
105105
- `--adalora_init_r`: 默认值12, adalora的初始rank

swift/llm/dpo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,7 @@ def llm_dpo(args: DPOArguments) -> str:
237237
if is_master():
238238
images_dir = os.path.join(args.output_dir, 'images')
239239
logger.info(f'images_dir: {images_dir}')
240-
tb_dir = os.path.join(args.output_dir, 'runs')
241-
plot_images(images_dir, tb_dir, ['train/loss'], 0.9)
240+
plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
242241
if args.push_to_hub:
243242
trainer._add_patterns_to_gitignore(['images/'])
244243
trainer.push_to_hub()

swift/llm/sft.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
304304
if is_master():
305305
images_dir = os.path.join(args.output_dir, 'images')
306306
logger.info(f'images_dir: {images_dir}')
307-
tb_dir = os.path.join(args.output_dir, 'runs')
308-
plot_images(images_dir, tb_dir, ['train/loss'], 0.9)
307+
plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
309308
if args.push_to_hub:
310309
trainer._add_patterns_to_gitignore(['images/'])
311310
trainer.push_to_hub()

swift/llm/utils/argument.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from swift import get_logger
1616
from swift.hub import HubApi, ModelScopeConfig
1717
from swift.utils import (add_version_to_work_dir, broadcast_string,
18-
get_dist_setting, is_dist, is_master, is_mp)
18+
get_dist_setting, get_pai_tensorboard_dir, is_dist,
19+
is_master, is_mp, is_pai_training_job)
1920
from .dataset import (DATASET_MAPPING, get_custom_dataset, get_dataset,
2021
register_dataset)
2122
from .model import (MODEL_MAPPING, dtype_mapping,
@@ -52,7 +53,7 @@ class SftArguments:
5253
f"template_type choices: {list(TEMPLATE_MAPPING.keys()) + ['AUTO']}"
5354
})
5455
output_dir: str = 'output'
55-
add_output_dir_suffix: bool = True
56+
add_output_dir_suffix: Optional[bool] = None
5657
ddp_backend: Literal['nccl', 'gloo', 'mpi', 'ccl'] = 'nccl'
5758

5859
seed: int = 42
@@ -214,6 +215,8 @@ def prepare_target_modules(self, target_modules):
214215

215216
def __post_init__(self) -> None:
216217
handle_compatibility(self)
218+
if is_pai_training_job():
219+
handle_pai_compat(self)
217220
ds_config_folder = os.path.join(__file__, '..', '..', 'ds_config')
218221
if self.deepspeed_config_path == 'default-zero2':
219222
self.deepspeed_config_path = os.path.abspath(
@@ -270,6 +273,8 @@ def __post_init__(self) -> None:
270273
if not dist.is_initialized():
271274
dist.init_process_group(backend=self.ddp_backend)
272275

276+
if self.add_output_dir_suffix is None:
277+
self.add_output_dir_suffix = True
273278
if self.add_output_dir_suffix:
274279
self.output_dir = os.path.join(self.output_dir, self.model_type)
275280
self.output_dir = add_version_to_work_dir(self.output_dir)
@@ -907,3 +912,17 @@ def handle_dataset_mixture(args: SftArguments, train_dataset,
907912
return concatenate_datasets([train_dataset, mixed_dataset])
908913
else:
909914
return train_dataset
915+
916+
917+
def handle_pai_compat(args: SftArguments) -> None:
918+
assert is_pai_training_job() is True
919+
logger.info('Handle pai compat...')
920+
pai_tensorboard_dir = get_pai_tensorboard_dir()
921+
if args.logging_dir is None and pai_tensorboard_dir is not None:
922+
args.logging_dir = pai_tensorboard_dir
923+
logger.info(f'Setting args.logging_dir: {args.logging_dir}')
924+
if args.add_output_dir_suffix is None:
925+
args.add_output_dir_suffix = False
926+
logger.info(
927+
f'Setting args.add_output_dir_suffix: {args.add_output_dir_suffix}'
928+
)

swift/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
get_model_info, is_ddp_plus_mp, is_dist,
1414
is_local_master, is_master, is_mp, is_on_same_device,
1515
seed_everything, show_layers, time_synchronize)
16-
from .utils import (add_version_to_work_dir, check_json_format, lower_bound,
16+
from .utils import (add_version_to_work_dir, check_json_format,
17+
get_pai_tensorboard_dir, is_pai_training_job, lower_bound,
1718
parse_args, read_multi_line, test_time, upper_bound)

swift/utils/utils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import datetime as dt
33
import os
44
import re
5+
import sys
56
import time
67
from typing import (Any, Callable, List, Mapping, Optional, Sequence, Tuple,
78
Type, TypeVar)
@@ -67,8 +68,15 @@ def add_version_to_work_dir(work_dir: str) -> str:
6768
def parse_args(class_type: Type[_T],
6869
argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
6970
parser = HfArgumentParser([class_type])
70-
args, remaining_args = parser.parse_args_into_dataclasses(
71-
argv, return_remaining_strings=True)
71+
if argv is None:
72+
argv = sys.argv[1:]
73+
if len(argv) > 0 and argv[0].endswith('.json'):
74+
json_path = os.path.abspath(os.path.expanduser(argv[0]))
75+
args, = parser.parse_json_file(json_path)
76+
remaining_args = argv[1:]
77+
else:
78+
args, remaining_args = parser.parse_args_into_dataclasses(
79+
argv, return_remaining_strings=True)
7280
return args, remaining_args
7381

7482

@@ -131,3 +139,11 @@ def read_multi_line() -> str:
131139
res[-1] = text[:-2]
132140
break
133141
return ''.join(res)
142+
143+
144+
def is_pai_training_job() -> bool:
145+
return 'PAI_TRAINING_JOB_ID' in os.environ
146+
147+
148+
def get_pai_tensorboard_dir() -> Optional[str]:
149+
return os.environ.get('PAI_OUTPUT_TENSORBOARD')

tests/llm/config/sft.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"model_type": "qwen-1_8b-chat",
3+
"dataset": "jd-sentiment-zh",
4+
"output_dir": "output/pai_test",
5+
"train_dataset_sample": 100,
6+
"eval_steps": 5
7+
}

tests/llm/test_run.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,30 @@ def test_dpo(self):
326326
load_dataset_config=True,
327327
val_dataset_sample=2))
328328

329+
def test_pai_compat(self):
330+
if not __name__ == '__main__':
331+
# ignore citest error in github
332+
return
333+
from swift.llm import sft_main, infer_main
334+
os.environ['PAI_TRAINING_JOB_ID'] = '123456'
335+
folder = os.path.join(os.path.dirname(__file__), 'config')
336+
tensorboard_dir = os.path.join('output/pai_test', 'pai_tensorboard')
337+
os.environ['PAI_OUTPUT_TENSORBOARD'] = tensorboard_dir
338+
sft_json = os.path.join(folder, 'sft.json')
339+
infer_json = os.path.join(folder, 'infer.json')
340+
output = sft_main([sft_json])
341+
print()
342+
infer_args = {
343+
'ckpt_dir': output['best_model_checkpoint'],
344+
'val_dataset_sample': 2,
345+
'load_dataset_config': True,
346+
}
347+
import json
348+
with open(infer_json, 'w') as f:
349+
json.dump(infer_args, f, ensure_ascii=False, indent=4)
350+
infer_main([infer_json])
351+
os.environ.pop('PAI_TRAINING_JOB_ID')
352+
329353

330354
def data_collate_fn(batch: List[Dict[str, Any]],
331355
tokenizer) -> Dict[str, Tensor]:

0 commit comments

Comments
 (0)