Commit b0e0fb0
Add script for ROME (#123)
1 parent 074178d commit b0e0fb0

11 files changed: +175 -10 lines changed
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+[
+    {
+        "prompt": "{} was the founder of",
+        "subject": "Steve Jobs",
+        "target": "Microsoft"
+    },
+    {
+        "prompt": "{} is located in",
+        "subject": "HangZhou",
+        "target": "Africa"
+    }
+]
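
Both entries are deliberately counterfactual: ROME rewrites a factual association inside the model, so "target" is the completion the edited model should produce, not the true fact. A minimal sketch (not part of the commit; the real consumption logic lives in swift/tuners/rome) of how such an entry is typically read, with the {} placeholder in "prompt" filled by "subject":

# Sketch only: illustrates the request format used above.
import json

with open('rome_example/request.json') as f:
    requests = json.load(f)

for r in requests:
    print(r['prompt'].format(r['subject']), '->', r['target'])
# Steve Jobs was the founder of -> Microsoft
# HangZhou is located in -> Africa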

examples/pytorch/llm/rome_infer.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from swift.llm.run import rome_main
+
+if __name__ == '__main__':
+    rome_main()
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Experimental environment: A10
+PYTHONPATH=../../.. \
+CUDA_VISIBLE_DEVICES=0 \
+python rome_infer.py \
+    --model_id_or_path modelscope/Llama-2-13b-chat-ms \
+    --model_revision master \
+    --template_type llama \
+    --dtype bf16 \
+    --eval_human true \
+    --max_new_tokens 128 \
+    --temperature 0.1 \
+    --top_k 50 \
+    --top_p 0.9 \
+    --do_sample true \
+    --rome_request_file rome_example/request.json
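
With --eval_human true, the edited model is queried interactively at the '<<< ' prompt implemented in rome_infer below; note that --top_p is parsed but not forwarded to the GenerationConfig constructed in swift/llm/rome.py.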

swift/llm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .infer import llm_infer
+from .rome import rome_infer
 from .sft import llm_sft
 from .utils import *

swift/llm/rome.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import torch
+from modelscope import GenerationConfig
+
+from swift.tuners import Swift
+from swift.utils import (get_logger, print_model_info, seed_everything,
+                         show_layers)
+from ..tuners.rome import RomeConfig
+from .utils import (RomeArguments, Template, get_dataset, get_model_tokenizer,
+                    get_template, inference)
+
+logger = get_logger()
+
+
+def rome_infer(args: RomeArguments) -> None:
+    logger.info(f'args: {args}')
+    logger.info(
+        'Rome does not support quantization for now, all quantization args will be ignored.'
+    )
+    logger.info(f'device_count: {torch.cuda.device_count()}')
+    seed_everything(args.seed)
+
+    # ### Loading Model and Tokenizer
+    model_kwargs = {'low_cpu_mem_usage': True, 'device_map': 'auto'}
+    kwargs = {'use_flash_attn': args.use_flash_attn}
+    model, tokenizer = get_model_tokenizer(args.model_type, args.torch_dtype,
+                                           model_kwargs, **kwargs)
+
+    with open(args.rome_request_file, 'r') as f:
+        request = json.load(f)
+
+    rome_type: str = None
+    if args.model_type in ('llama2-13b-chat', 'llama2-13b', 'llama-13b-chat',
+                           'llama-13b'):
+        rome_type = 'llama-13b'
+    elif args.model_type in ('llama2-7b-chat', 'llama2-7b', 'llama-7b-chat',
+                             'llama-7b'):
+        rome_type = 'llama-7b'
+
+    config = RomeConfig(
+        model_type=rome_type,
+        knowledge=request,
+        tokenizer=tokenizer,
+    )
+    model = Swift.prepare_model(model, config, inference_mode=True)
+
+    show_layers(model)
+    print_model_info(model)
+
+    # ### Inference
+    template: Template = get_template(args.template_type, tokenizer,
+                                      args.system, args.max_length)
+    generation_config = GenerationConfig(
+        max_length=None,
+        max_new_tokens=args.max_new_tokens,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        do_sample=args.do_sample,
+        repetition_penalty=args.repetition_penalty,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id)
+    logger.info(f'generation_config: {generation_config}')
+    if args.overwrite_generation_config:
+        generation_config.save_pretrained(args.ckpt_dir)
+    model.generation_config = generation_config
+
+    if args.eval_human:
+        while True:
+            query = input('<<< ')
+            data = {'query': query}
+            input_ids = template.encode(data)['input_ids']
+            inference(input_ids, model, tokenizer, args.stream)
+    else:
+        _, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio,
+                                     args.dataset_seed)
+        mini_val_dataset = val_dataset.select(
+            range(min(args.show_dataset_sample, val_dataset.shape[0])))
+        for data in mini_val_dataset:
+            response = data['response']
+            data['response'] = None
+            input_ids = template.encode(data)['input_ids']
+            inference(input_ids, model, tokenizer, args.stream)
+            print()
+            print(f'[LABELS]{response}')
+            print('-' * 80)
+            # input('next[ENTER]')
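
For driving the same flow from Python rather than the shell script, a hedged sketch follows. It assumes the CLI flags map one-to-one onto RomeArguments fields (model_id_or_path, dtype and eval_human would be inherited from InferArguments but are not shown in this diff) and that the CLI entry point calls init_argument() before dispatching:

# Hedged usage sketch, not verified against the full swift API.
from swift.llm import RomeArguments, rome_infer

args = RomeArguments(
    model_id_or_path='modelscope/Llama-2-13b-chat-ms',  # value from the example script
    template_type='llama',
    dtype='bf16',
    eval_human=True,
    rome_request_file='rome_example/request.json')
args.init_argument()  # RomeArguments initializes manually, unlike __post_init__
rome_infer(args)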

swift/llm/run.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import (InferArguments, SftArguments, get_main, llm_infer,
-                       llm_sft)
+from swift.llm import (InferArguments, RomeArguments, SftArguments, get_main,
+                       llm_infer, llm_sft, rome_infer)
 
 sft_main = get_main(SftArguments, llm_sft)
 infer_main = get_main(InferArguments, llm_infer)
+rome_main = get_main(RomeArguments, rome_infer)
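
get_main itself is untouched by this commit and its body is not part of the diff; the pattern is "parse the flags into the argument dataclass, then dispatch to the matching function". A rough illustration, assuming a standard HfArgumentParser-style wrapper (the real implementation may differ):

# Illustration only -- not the actual swift.llm.get_main.
from transformers import HfArgumentParser

def get_main(args_class, llm_func):
    def main():
        args, = HfArgumentParser([args_class]).parse_args_into_dataclasses()
        args.init_argument()  # cf. RomeArguments.init_argument below
        return llm_func(args)
    return main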

swift/llm/sft.py

Lines changed: 15 additions & 1 deletion
@@ -54,7 +54,7 @@ def llm_sft(args: SftArguments) -> str:
                                            model_kwargs, **kwargs)
 
     # ### Preparing LoRA
-    if args.sft_type == 'lora' or args.sft_type == 'longlora':
+    if args.sft_type in ('lora', 'qalora', 'longlora'):
         if args.resume_from_checkpoint is None:
             if 'ALL' in args.lora_target_modules:
                 assert len(args.lora_target_modules) == 1
@@ -88,6 +88,20 @@ def llm_sft(args: SftArguments) -> str:
                     use_flash_attn=args.use_flash_attn)
                 model = Swift.prepare_model(model, longlora_config)
                 logger.info(f'longlora_config: {longlora_config}')
+            elif args.sft_type == 'qalora':
+                assert getattr(
+                    model, 'quantization_method',
+                    None) == 'gptq', 'qalora must be used with auto_gptq'
+                lora_kwargs = {}
+                lora_config = LoRAConfig(
+                    r=args.lora_rank,
+                    target_modules=args.lora_target_modules,
+                    lora_alpha=args.lora_alpha,
+                    lora_dropout=args.lora_dropout_p,
+                    use_qa_lora=True,
+                    **lora_kwargs)
+                model = Swift.prepare_model(model, lora_config)
+                logger.info(f'lora_config: {lora_config}')
         else:
             model = Swift.from_pretrained(
                 model, args.resume_from_checkpoint, is_trainable=True)
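
The new qalora branch differs from plain LoRA only in use_qa_lora=True plus the fail-fast assert: QA-LoRA adapts group-wise on top of GPTQ-quantized weights, so the base model must have been loaded through auto_gptq, which is what the quantization_method == 'gptq' check detects.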

swift/llm/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from .argument import InferArguments, SftArguments
+from .argument import InferArguments, RomeArguments, SftArguments
 from .dataset import (DATASET_MAPPING, AlpacaPreprocessor,
                       ConversationsPreprocessor, DatasetName,
                       GetDatasetFunction, get_dataset, get_dataset_from_repo,

swift/llm/utils/argument.py

Lines changed: 32 additions & 3 deletions
@@ -31,7 +31,8 @@ class SftArguments:
     model_cache_dir: Optional[str] = None
 
     sft_type: str = field(
-        default='lora', metadata={'choices': ['longlora', 'lora', 'full']})
+        default='lora',
+        metadata={'choices': ['lora', 'longlora', 'qalora', 'full']})
     tuner_backend: str = field(
         default='swift', metadata={'choices': ['swift', 'peft']})
     template_type: Optional[str] = field(
@@ -158,7 +159,7 @@ def init_argument(self):
         # Make sure to set the same output_dir when using DDP.
         self.output_dir = broadcast_string(self.output_dir)
 
-        if self.sft_type == 'lora' or self.sft_type == 'longlora':
+        if self.sft_type in ('lora', 'longlora', 'qalora'):
             if self.learning_rate is None:
                 self.learning_rate = 1e-4
             if self.only_save_model is None:
@@ -224,7 +225,8 @@ class InferArguments:
     model_revision: Optional[str] = None
 
     sft_type: str = field(
-        default='lora', metadata={'choices': ['longlora', 'lora', 'full']})
+        default='lora',
+        metadata={'choices': ['lora', 'longlora', 'qalora', 'full']})
     template_type: Optional[str] = field(
         default=None,
         metadata={
@@ -291,6 +293,33 @@ def init_argument(self):
             self.max_length = None
 
 
+@dataclass
+class RomeArguments(InferArguments):
+
+    rome_request_file: str = field(
+        default=None,
+        metadata={
+            'help':
+            'The rome request file, please check the documentation '
+            'to get the format'
+        })
+
+    def init_argument(self):
+        # Can be manually initialized, unlike __post_init__
+        handle_compatibility(self)
+        set_model_type(self)
+        handle_dir(self)
+
+        self.torch_dtype, _, _ = select_dtype(self)
+        if self.template_type is None:
+            self.template_type = MODEL_MAPPING[self.model_type]['template']
+            logger.info(f'Setting template_type: {self.template_type}')
+
+        assert isinstance(self.dataset, (list, tuple))
+        if self.max_length == -1:
+            self.max_length = None
+
+
 dtype_mapping_reversed = {v: k for k, v in dtype_mapping.items()}
 
 
swift/tuners/rome/rome.py

Lines changed: 2 additions & 2 deletions
@@ -147,7 +147,7 @@ def execute_rome(
         layer,
         context_template,
     )
-    logger.info('Left vector shape:', left_vector.shape)
+    logger.info(f'Left vector shape: {left_vector.shape}')
     right_vector: torch.Tensor = compute_v(
         model,
         tok,
@@ -157,7 +157,7 @@ def execute_rome(
         left_vector,
         context_template,
     )
-    logger.info('Right vector shape:', right_vector.shape)
+    logger.info(f'Right vector shape: {right_vector.shape}')
     right_vector = right_vector.to(left_vector.dtype)
 
     with torch.no_grad():
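
The two one-line changes fix a real logging bug, assuming get_logger returns a standard library logging.Logger: stdlib logging treats extra positional arguments as lazy %-format arguments, and a message without a placeholder fails to format at emit time, so the shapes were never logged. A self-contained reproduction:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('demo')
shape = (5120,)

# Before: no %s placeholder for the extra argument, so emitting raises a
# formatting error (caught by logging and reported to stderr); the shape
# never reaches the log output.
logger.info('Left vector shape:', shape)

# After: format eagerly with an f-string.
logger.info(f'Left vector shape: {shape}')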
