Commit 1e9f8be

[TorchAcc][Experimental] Integrate TorchAcc. (#647)
1 parent 985eea3 commit 1e9f8be

File tree

16 files changed: +606 -24 lines changed

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Experimental environment: 4 * 8*A100
+# 80GB GPU memory
+# Note: TorchAcc is currently only available internally.
+
+export USE_TORCHACC=1
+export XLA_FLAGS='--xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner'
+export XLA_IR_SHAPE_CACHE_SIZE=100000000
+export XLA_ALLOCATOR_FRACTION=0.97
+
+# Note: You need to set the correct MASTER_ADDR, MASTER_PORT and NODE_RANK for each node.
+
+MASTER_ADDR=127.0.0.1 \
+MASTER_PORT=12456 \
+NODE_RANK=0 \
+NNODES=4 \
+NPROC_PER_NODE=8 \
+swift sft \
+    --model_type qwen-72b-chat \
+    --model_layer_cls_name QWenBlock \
+    --dataset codefuse-python-en \
+    --sft_type full \
+    --output_dir output \
+    --num_train_epochs 1 \
+    --max_length 1024 \
+    --batch_size 1 \
+    --use_flash_attn true \
+    --gradient_accumulation_steps 1 \
+    --gradient_checkpointing no \
+    --tuner_backend 'peft' \
+    --eval_steps 200 \
+    --save_steps 200 \
+    --logging_steps 100 \
+    --report_to 'none'
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Experimental environment: 4 * A800
+# 80GB GPU memory
+# Note: TorchAcc is currently only available internally.
+
+export USE_TORCHACC=1
+export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization'
+export XLA_IR_SHAPE_CACHE_SIZE=100000000
+export XLA_ALLOCATOR_FRACTION=0.95
+export XLA_EXPERIMENTAL=nonzero:masked_select
+
+NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+swift sft \
+    --model_type qwen-72b-chat \
+    --model_layer_cls_name QWenBlock \
+    --dataset codefuse-python-en \
+    --sft_type lora \
+    --output_dir output_qwen_72b \
+    --num_train_epochs 1 \
+    --max_length 2048 \
+    --batch_size 6 \
+    --use_flash_attn true \
+    --gradient_accumulation_steps 1 \
+    --gradient_checkpointing no \
+    --tuner_backend 'peft' \
+    --eval_steps 200 \
+    --save_steps 200 \
+    --logging_steps 100 \
+    --report_to 'none'

swift/llm/accelerator.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+
+def ta_accelerate(model,
+                  fsdp_num,
+                  layer_cls_name,
+                  bf16=True,
+                  fp16=False,
+                  gradient_checkpointing=True,
+                  fsdp_flatten_parameters=False):
+    """Accelerate LLM training using TorchAcc (only available internally).
+    """
+    import torchacc as ta
+    assert layer_cls_name is not None
+
+    def get_ta_config():
+        config = ta.Config()
+        config.compute.fp16 = fp16
+        config.compute.bf16 = bf16
+
+        config.memory.gc = gradient_checkpointing
+        if config.memory.gc:
+            config.memory.gc_cls = {layer_cls_name}
+
+        config.dist.fsdp.size = fsdp_num
+        config.dist.fsdp.wrap_layer_cls = {layer_cls_name}
+        config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters
+
+        return config
+
+    ta_config = get_ta_config()
+    model = ta.accelerate(model, ta_config)
+    return model
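
For orientation, here is a minimal sketch of how `ta_accelerate` might be called directly, mirroring the call added in swift/llm/sft.py below. The checkpoint id, dtype, and device count are illustrative assumptions, and `torchacc` itself is internal-only, so this will not run outside that environment.

# Minimal usage sketch (assumptions: torchacc is available, 4 GPUs, and a Qwen
# checkpoint whose decoder block class is 'QWenBlock').
import torch
from transformers import AutoModelForCausalLM

from swift.llm.accelerator import ta_accelerate

model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-72B-Chat',            # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True)
model = ta_accelerate(
    model,
    fsdp_num=4,                      # shard parameters across 4 devices via FSDP
    layer_cls_name='QWenBlock',      # decoder block to wrap and checkpoint
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    fsdp_flatten_parameters=False)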

swift/llm/sft.py

Lines changed: 48 additions & 5 deletions
@@ -12,11 +12,13 @@
 from transformers.utils import is_torch_npu_available

 from swift.trainers import Seq2SeqTrainer
+from swift.trainers.utils import can_return_loss, find_labels
 from swift.utils import (check_json_format, compute_acc_metrics,
                          compute_nlg_metrics, get_dist_setting, get_logger,
                          get_main, get_model_info, is_ddp_plus_mp, is_dist,
                          is_master, plot_images, preprocess_logits_for_metrics,
-                         seed_everything, show_layers)
+                         seed_everything, show_layers, use_torchacc)
+from .accelerator import ta_accelerate
 from .tuner import prepare_model
 from .utils import (TEMPLATE_MAPPING, LazyLLMDataset, SftArguments, Template,
                     add_self_cognition_dataset, dataset_map, get_dataset,
@@ -55,15 +57,15 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     model_kwargs = {'low_cpu_mem_usage': True}
     if is_dist() and not is_ddp_plus_mp():
         model_kwargs['device_map'] = {'': local_rank}
-    else:
+    elif not use_torchacc():
         model_kwargs['device_map'] = 'auto'
+
     if args.load_in_8bit or args.load_in_4bit:
         quantization_config = BitsAndBytesConfig(
             args.load_in_8bit,
             args.load_in_4bit,
             bnb_4bit_compute_dtype=args.bnb_4bit_compute_dtype,
             bnb_4bit_quant_type=args.bnb_4bit_quant_type,
-            bnb_4bit_quant_storage=args.bnb_4bit_quant_storage,
             bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant)
         logger.info(f'quantization_config: {quantization_config.__dict__}')
         model_kwargs['quantization_config'] = quantization_config
@@ -93,6 +95,13 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     set_generation_config(model, generation_config)
     training_args.generation_config = generation_config

+    if use_torchacc():
+        import torchacc as ta
+        # Get `label_names` and `return_loss` before `ta_accelerate`, because
+        # wrapping the model makes these properties unreliable afterwards.
+        label_names = find_labels(model)
+        return_loss = can_return_loss(model)
+        model = ta.patch_qwen_model(model)
     # Preparing LoRA
     model, callbacks = prepare_model(model, args)

@@ -108,6 +117,18 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         logger.info('Setting model.config.use_cache: False')
         model.enable_input_require_grads()

+    if use_torchacc():
+        model.config.use_cache = False
+        logger.info('Setting model.config.use_cache: False')
+        model = ta_accelerate(
+            model,
+            world_size,
+            args.model_layer_cls_name,
+            args.bf16,
+            args.fp16,
+            gradient_checkpointing=True,
+            fsdp_flatten_parameters=False)
+
     # Loading Dataset
     random_state = np.random.RandomState(args.dataset_seed)
     train_dataset, val_dataset = get_dataset(
@@ -185,6 +206,15 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     padding_to = args.max_length if args.sft_type == 'longlora' else None
     data_collator = partial(template.data_collator, padding_to=padding_to)

+    train_batch_size = args.batch_size
+    eval_batch_size = args.eval_batch_size
+    if use_torchacc():
+        train_batch_size *= world_size
+        eval_batch_size *= world_size
+        training_args.per_device_train_batch_size = train_batch_size
+        training_args.per_device_eval_batch_size = eval_batch_size
+        training_args.group_by_length = use_torchacc()
+
     # Trainer
     logger.info(f'training_args: {training_args}')

@@ -211,6 +241,9 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         callbacks=callbacks,
         **trainer_kwargs)
     trainer.sft_args = args
+    if use_torchacc():
+        trainer.label_names = label_names
+        trainer.can_return_loss = return_loss
     if is_master():
         for args_obj, fname in zip([args, training_args],
                                    ['sft_args.json', 'training_args.json']):
@@ -233,7 +266,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
     train_time = get_time_info(trainer.state.log_history, len(train_dataset))
     # Visualization
-    if is_master():
+    if is_master() and not use_torchacc():
         images_dir = os.path.join(args.output_dir, 'images')
         logger.info(f'images_dir: {images_dir}')
         plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
@@ -253,4 +286,14 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     }


-sft_main = get_main(SftArguments, llm_sft)
+def get_sft_main(args, llm):
+    if use_torchacc():
+        logger.warning('TorchAcc is currently only available internally '
+                       'within Alibaba Cloud.')
+        import torchacc as ta
+        # This patch should be called before `llm_sft`.
+        ta.accelerate_hf_trainer()
+    return get_main(args, llm)
+
+
+sft_main = get_sft_main(SftArguments, llm_sft)
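
As a side note on the batch-size change above: under TorchAcc the per-device batch size handed to the Trainer becomes the global batch (per-rank batch size times world size), which the data collator (see the template.py change further down) then appears to split per rank via `pad_and_split_batch`. A quick arithmetic sketch, using illustrative values from the 4 * A800 LoRA script:

# Illustrative arithmetic only; values borrowed from the A800 LoRA script above.
batch_size = 6          # --batch_size (per rank)
world_size = 4          # NPROC_PER_NODE=4, single node
per_device_train_batch_size = batch_size * world_size
print(per_device_train_batch_size)  # 24, later split across ranks by the collator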

swift/llm/tuner.py

Lines changed: 14 additions & 1 deletion
@@ -1,18 +1,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import types

 import torch
 import transformers
 from packaging import version

+from swift.torchacc_utils import consolidate_checkpoint
 from swift.trainers import TrainerCallback
 from swift.tuners import (AdaLoraConfig, IA3Config, LongLoRAConfig,
                           LongLoRAModelType, LoraConfig, LoRAConfig,
                           NEFTuneConfig, Swift)
 from swift.tuners.llamapro import LLaMAProConfig
 from swift.tuners.module_mapping import MODEL_KEYS_MAPPING
 from swift.utils import (activate_model_parameters, freeze_model_parameters,
-                         get_logger)
+                         get_logger, use_torchacc)
 from .utils import (SftArguments, find_all_linears, find_embedding, find_ln,
                     is_adapter)

@@ -149,6 +151,9 @@ def prepare_model(model, args: SftArguments):
             model = Swift.prepare_model(model, llamapro_config)
             logger.info(f'llamapro_config: {llamapro_config}')
         else:
+            if use_torchacc():
+                consolidate_checkpoint(args.resume_from_checkpoint,
+                                       'adapter_model')
             model = Swift.from_pretrained(
                 model, args.resume_from_checkpoint, is_trainable=True)
         # fix bug: Attempting to unscale FP16 gradients.
@@ -168,6 +173,14 @@ def prepare_model(model, args: SftArguments):
         if len(args.additional_trainable_parameters) > 0:
             activate_model_parameters(model,
                                       args.additional_trainable_parameters)
+        if use_torchacc() and args.resume_from_checkpoint is not None:
+            consolidate_checkpoint(args.resume_from_checkpoint, 'model')
+            weights_file = os.path.join(args.resume_from_checkpoint,
+                                        'model.bin')
+            state_dict = torch.load(weights_file, map_location='cpu')
+            model.load_state_dict(state_dict, False)
+            # release memory
+            del state_dict
     else:
         raise ValueError(f'args.sft_type: {args.sft_type}')

swift/llm/utils/argument.py

Lines changed: 6 additions & 0 deletions
@@ -49,6 +49,12 @@ class SftArguments:
         metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
     model_id_or_path: Optional[str] = None
     model_revision: Optional[str] = None
+    model_layer_cls_name: Optional[str] = field(
+        default=None,
+        metadata={
+            'help':
+            "Decoder class name of the model, e.g. 'QWenBlock' for Qwen, 'LlamaDecoderLayer' for LLaMA"
+        })

     sft_type: Literal['lora', 'full', 'longlora', 'qalora', 'adalora', 'ia3',
                       'llamapro'] = 'lora'
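
The new `model_layer_cls_name` argument expects the class name of the model's decoder block. If that name is not known for a given checkpoint, a small hypothetical helper like the one below (not part of this commit; the model id is only an example) can list repeated module classes to inspect:

# Hypothetical helper, not part of this commit: list repeated module classes so
# the decoder block to pass as --model_layer_cls_name can be spotted
# (e.g. 'QWenBlock' for Qwen, 'LlamaDecoderLayer' for LLaMA).
from collections import Counter

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-1_8B-Chat', trust_remote_code=True)
counts = Counter(type(m).__name__ for m in model.modules())
print([name for name, n in counts.most_common(10) if n > 1])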

swift/llm/utils/model.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from transformers.utils.versions import require_version

 from swift import get_logger
-from swift.utils import is_dist, is_local_master
+from swift.utils import is_dist, is_local_master, use_torchacc
 from .template import TemplateType
 from .utils import get_max_model_len

@@ -2952,7 +2952,7 @@ def get_model_tokenizer(
     get_function = model_info['get_function']
     if model_kwargs is None:
         model_kwargs = {}
-    if 'device_map' not in model_kwargs:
+    if 'device_map' not in model_kwargs and not use_torchacc():
         model_kwargs['device_map'] = 'auto'

     if model_info.get('torch_dtype') is not None:

swift/llm/utils/template.py

Lines changed: 9 additions & 1 deletion
@@ -11,6 +11,8 @@
 from transformers import PreTrainedTokenizerBase, StoppingCriteria

 from swift.llm.agent.utils import calculate_loss_scale
+from swift.torchacc_utils import pad_and_split_batch
+from swift.utils import get_dist_setting, use_torchacc

 DEFAULT_SYSTEM = 'You are a helpful assistant.'
 History = List[Union[Tuple[str, str], List[str]]]
@@ -429,12 +431,18 @@ def data_collator(self,
                 loss_scale, batch_first=True, padding_value=0.)
         labels = pad_sequence(labels, batch_first=True, padding_value=-100)

+        if use_torchacc():
+            rank, _, world_size, _ = get_dist_setting()
+            input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
+                padding_to, input_ids, attention_mask, labels, loss_scale,
+                self.max_length, self.tokenizer, rank, world_size)
+
         res = {
             'input_ids': input_ids,
             'attention_mask': attention_mask,
             'labels': labels,
         }
-        if loss_scale is not None:
+        if loss_scale:
             res['loss_scale'] = loss_scale
         return res

swift/llm/utils/utils.py

Lines changed: 5 additions & 1 deletion
@@ -5,6 +5,7 @@
 import logging
 import os
 import shutil
+import sys
 from copy import deepcopy
 from functools import partial, wraps
 from queue import Empty, Queue
@@ -40,7 +41,8 @@
 from swift.hub import ModelScopeConfig
 from swift.tuners.module_mapping import MODEL_KEYS_MAPPING
 from swift.utils import (get_dist_setting, get_logger, is_ddp_plus_mp, is_dist,
-                         is_local_master, is_master, stat_array, upper_bound)
+                         is_local_master, is_master, stat_array, upper_bound,
+                         use_torchacc)
 from .template import History, StopWords, StopWordsCriteria, Template

 logger = get_logger()
@@ -868,6 +870,8 @@ def get_max_model_len(config: PretrainedConfig) -> Optional[int]:
             _old_ddp_init(self, model, *args, **kwargs))
     transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None
     transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch
+
+if is_ddp_plus_mp() or use_torchacc():
     _old_accelerator_init = trainer.Accelerator.__init__
     trainer.Accelerator.__init__ = (
         lambda self, device_placement=False, *args, **kwargs:
