Commit 24a71e4

Support ROME (#121)
1 parent fa282f9 commit 24a71e4

File tree

11 files changed: +1309, −0 lines


swift/tuners/mapping.py

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
 from .lora import LoRA, LoRAConfig
 from .prompt import Prompt, PromptConfig
 from .restuning import ResTuning, ResTuningConfig
+from .rome import Rome, RomeConfig
 from .side import Side, SideConfig

@@ -14,6 +15,7 @@ class SwiftTuners:
     LORA = 'LORA'
     SIDE = 'SIDE'
     RESTUNING = 'RESTUNING'
+    ROME = 'ROME'
     LONGLORA = 'longlora'

@@ -23,5 +25,6 @@ class SwiftTuners:
     SwiftTuners.LORA: (LoRAConfig, LoRA),
     SwiftTuners.SIDE: (SideConfig, Side),
     SwiftTuners.RESTUNING: (ResTuningConfig, ResTuning),
+    SwiftTuners.ROME: (RomeConfig, Rome),
     SwiftTuners.LONGLORA: (LongLoRAConfig, LongLoRA),
 }
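
For reference, this registration is what lets the framework resolve a tuner by its type name. A minimal sketch of the lookup, assuming the dict whose entries appear in the last hunk is named SWIFT_MAPPING (the name itself is not visible in this diff):

# Sketch only; SWIFT_MAPPING is an assumed name, not shown in the hunk above.
from swift.tuners.mapping import SWIFT_MAPPING, SwiftTuners

config_cls, tuner_cls = SWIFT_MAPPING[SwiftTuners.ROME]
print(config_cls.__name__, tuner_cls.__name__)  # RomeConfig Rome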

swift/tuners/rome/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .rome import Rome, RomeConfig

swift/tuners/rome/compute_u.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from kmeng01/rome.
from typing import Dict, List

import torch
from modelscope import AutoTokenizer

from swift.utils.logger import get_logger
from .repr_tools import get_reprs_at_idxs, get_reprs_at_word_tokens
from .rome_hparams import ROMEHyperParams

logger = get_logger()


def compute_u(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    context_templates: List[str],
) -> torch.Tensor:
    """
    Computes the left vector used in constructing the rank-1 update matrix.
    """

    logger.info('Computing left vector (u)...')

    # Compute the projection token representation
    word_repr_args = dict(
        model=model,
        tokenizer=tokenizer,
        layer=layer,
        module_template=hparams.rewrite_module_tmp,
        track='in',
    )
    if hparams.fact_token.startswith('subject_'):
        word = request['subject']
        logger.info(f'Selected u projection object {word}')
        cur_repr = get_reprs_at_word_tokens(
            context_templates=[
                templ.format(request['prompt']) for templ in context_templates
            ],
            words=[word for _ in range(len(context_templates))],
            subtoken=hparams.fact_token[len('subject_'):],
            **word_repr_args,
        ).mean(0)
    elif hparams.fact_token == 'last':
        # Heuristic to choose the last word. A minor edge case (e.g. a
        # multi-token word) is not a big deal because the function below
        # takes the last token anyway.
        cur_repr = get_reprs_at_idxs(
            contexts=[
                templ.format(request['prompt'].format(request['subject']))
                for templ in context_templates
            ],
            idxs=[[-1] for _ in range(len(context_templates))],
            **word_repr_args,
        ).mean(0)
        logger.info('Selected u projection token with last token')
    else:
        raise ValueError(f'fact_token={hparams.fact_token} not recognized')

    # NOTE: unlike the original ROME, no inverse second-moment adjustment is
    # applied here; the averaged representation is simply normalized.
    u = cur_repr
    return u / u.norm()
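
A note on the two strategies above, since the nested formatting is easy to misread: each context template carries one '{}' slot that receives the request prompt, and the prompt itself keeps a '{}' slot for the subject. A minimal illustration with an invented request (values are hypothetical, not from this commit):

# Hypothetical request; the keys match what compute_u reads.
request = {'prompt': 'The capital of {} is', 'subject': 'France'}
context_templates = ['{}', 'Human: {}']

# 'subject_*' strategy: the subject slot stays open for repr_tools to locate.
contexts = [templ.format(request['prompt']) for templ in context_templates]
print(contexts)  # ['The capital of {} is', 'Human: The capital of {} is']

# 'last' strategy: the subject is filled in and the final token is used.
filled = [
    templ.format(request['prompt'].format(request['subject']))
    for templ in context_templates
]
print(filled)  # ['The capital of France is', 'Human: The capital of France is']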

swift/tuners/rome/compute_v.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from kmeng01/rome.
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from modelscope import AutoTokenizer

from swift.utils.logger import get_logger
from .nethook import TraceDict, set_requires_grad
from .repr_tools import (get_reprs_at_idxs, get_reprs_at_word_tokens,
                         get_words_idxs_in_templates)
from .rome_hparams import ROMEHyperParams

logger = get_logger()


def compute_v(model: torch.nn.Module,
              tokenizer: AutoTokenizer,
              request: Dict,
              hparams: ROMEHyperParams,
              layer: int,
              left_vector: torch.Tensor,
              context_templates: List[str],
              batch_first: bool = True) -> torch.Tensor:
    """
    Computes the value (right) vector for the rank-1 update by running
    a simple optimization procedure.
    """

    logger.info('Computing right vector (v)')

    # Compile the lists of rewriting prompts and KL prompts
    rewriting_prompts, kl_prompts = [
        context.format(request['prompt']) + request['target']
        for context in context_templates
    ], ['{} is a', '{}是一个']
    all_prompts = rewriting_prompts + kl_prompts

    input_tok = tokenizer(
        [prompt.format(request['subject']) for prompt in all_prompts],
        return_tensors='pt',
        padding=True,
        return_token_type_ids=False,
    ).to(model.device)

    # Compute rewriting targets
    rewriting_targets = torch.tensor(
        -100, device=model.device).repeat(
            len(rewriting_prompts), *input_tok['input_ids'].shape[1:])

    prompt = context_templates[0].format(request['prompt'])
    prompt_full = prompt + request['target']
    target_len = len(tokenizer.tokenize(prompt_full)) - len(
        tokenizer.tokenize(prompt))
    for i in range(len(rewriting_prompts)):
        rewriting_targets[i, -target_len - 1:-1] = input_tok['input_ids'][
            i, -target_len:].clone()

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        find_fact_lookup_idx(
            prompt,
            request['subject'],
            tokenizer,
            hparams.fact_token,
            verbose=(i == 0)) for i, prompt in enumerate(all_prompts)
    ]

    # Finalize rewrite and loss layers
    logger.info(f'Rewrite layer is {layer}')

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer (i.e. the hypothesized fact lookup location), will induce
    # the target token to be predicted at the final layer.
    hidden_size = model.config.n_embd if hasattr(
        model.config, 'n_embd') else model.config.hidden_size
    delta = torch.zeros((hidden_size, ),
                        requires_grad=True,
                        device=model.device)
    target_init, kl_distr_init = None, None

    # Inserts the new "delta" variable at the appropriate part of the computation
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init

        # Store the initial value of the vector of interest
        if target_init is None:
            logger.info('Recording initial value of v*')
            # The initial value is recorded for the clean sentence
            target_init = cur_out[0, lookup_idxs[0]].detach().clone()

        for i, idx in enumerate(lookup_idxs):
            cur_out[i, idx, :] += delta

        return cur_out

    # Optimizer
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    set_requires_grad(False, model)

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation
        with TraceDict(
                module=model,
                layers=[
                    hparams.mlp_module_tmp.format(layer),
                ],
                retain_input=False,
                retain_output=True,
                edit_output=edit_output_fn,
        ) as _:
            logits = model(**input_tok).logits

            # Compute distribution for KL divergence
            kl_logits = torch.stack(
                [
                    logits[i - len(kl_prompts), idx, :]
                    for i, idx in enumerate(lookup_idxs[-len(kl_prompts):])
                ],
                dim=0,
            )
            kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
            if kl_distr_init is None:
                kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        log_probs = torch.log_softmax(logits, dim=2)

        loss = torch.gather(
            log_probs,
            2,
            torch.where(rewriting_targets != -100, rewriting_targets,
                        0).unsqueeze(2),
        ).squeeze(2)
        mask = (rewriting_targets != -100).float()

        # Aggregate total losses
        nll_loss_each = -(loss * mask).sum(1) / target_len
        nll_loss = nll_loss_each.mean()
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init,
            kl_log_probs,
            log_target=True,
            reduction='batchmean')
        weight_decay = hparams.v_weight_decay * (
            torch.norm(delta) / torch.norm(target_init)**2)
        # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
        loss = nll_loss + kl_loss + weight_decay
        logger.info(
            f'loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + '
            f'{np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} '
            f"avg prob of [{request['target']}] "
            f'{torch.exp(-nll_loss_each).mean().item()}')
        if loss < 5e-2:
            break

        # Skip the parameter update on the final iteration
        if it == hparams.v_num_grad_steps - 1:
            break

        # Backpropagate
        loss.backward()
        opt.step()

        # Project within the L2 ball
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta

    # Retrieve cur_input, the current input to the 2nd MLP layer, and
    # cur_output, the original output of the 2nd MLP layer.
    cur_input, cur_output = get_module_input_output_at_word(
        model,
        tokenizer,
        layer,
        context_template=request['prompt'],
        word=request['subject'],
        module_template=hparams.rewrite_module_tmp,
        fact_token_strategy=hparams.fact_token,
        batch_first=batch_first)

    # Solve the linear system to compute the right vector
    right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
    logger.info(f'Delta norm: {(target - cur_output).norm().item()}')
    logger.info(
        f'Change in target norm: {target_init.norm().item()} to '
        f'{target.norm().item()} => {(target.norm() - target_init.norm()).item()}')
    logger.info(f'Division factor: {torch.dot(cur_input, left_vector).item()}')
    logger.info(f'Right vector norm: {right_vector.norm()}')

    return right_vector


def get_module_input_output_at_word(
        model: torch.nn.Module,
        tok: Any,
        layer: int,
        context_template: str,
        word: str,
        module_template: str,
        fact_token_strategy: str,
        batch_first: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Retrieves detached representations for a word at the input and output
    of a particular layer module.
    """

    word_repr_args = dict(
        model=model,
        tokenizer=tok,
        layer=layer,
        module_template=module_template,
        batch_first=batch_first)
    if fact_token_strategy.startswith('subject_'):
        subtoken = fact_token_strategy[len('subject_'):]
        l_input, l_output = get_reprs_at_word_tokens(
            track='both',
            subtoken=subtoken,
            context_templates=[context_template],
            words=[word],
            **word_repr_args,
        )
    elif fact_token_strategy == 'last':
        l_input, l_output = get_reprs_at_idxs(
            track='both',
            contexts=[context_template.format(word)],
            idxs=[[-1]],
            **word_repr_args,
        )
    else:
        raise ValueError(f'fact_token={fact_token_strategy} not recognized')

    l_input, l_output = l_input[0], l_output[0]
    return l_input.detach(), l_output.detach()


def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: Any,
    fact_token_strategy: str,
    verbose=True,
) -> int:
    """
    Computes the hypothesized fact lookup index given a sentence and subject.
    """

    if fact_token_strategy == 'last':
        ret = -1
    elif fact_token_strategy.startswith('subject_'):
        ret = get_words_idxs_in_templates(
            tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=fact_token_strategy[len('subject_'):],
        )[0][0]
    else:
        raise ValueError(f'fact_token={fact_token_strategy} not recognized')

    sentence = prompt.format(subject)
    if verbose:
        logger.info(
            f'Lookup index found: {ret} | Sentence: {sentence} | Token: '
            + tok.decode(tok(sentence)['input_ids'][ret]))

    return ret
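
Taken together with compute_u, the returned right_vector realizes the usual rank-1 ROME edit W' = W + v ⊗ u: dividing by torch.dot(cur_input, left_vector) is exactly what makes the edited module map the lookup key to the optimized target. A toy numeric check of that identity (shapes and values invented, not code from this commit):

import torch

d = 8
W = torch.randn(d, d)                 # weight of the rewritten module
k = torch.randn(d)                    # cur_input: key at the lookup token
u = torch.randn(d)
u = u / u.norm()                      # left vector from compute_u
target = torch.randn(d)               # v* = target_init + delta

cur_output = W @ k
right = (target - cur_output) / torch.dot(k, u)  # as in compute_v

W_edited = W + torch.outer(right, u)  # the rank-1 update
assert torch.allclose(W_edited @ k, target, atol=1e-4)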
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
context_template = [
    '{}',
    'Human: {}',
    'Input: {}',
    'User: {}',
    'The city is beautiful, {}',
    'Today is a sunny day, {}',
    'America has a long coastline. {}',
    'The dogs are barking. {}',
    'These flowers need water. {}',
    'This city is good for the health, {}',
    'They are good at cooking fish and noodles, {}',
    'The supermarket here sells cheap today, {}',
    '今天是个晴天,{}',  # 'Today is a sunny day, {}'
    '这座城市很漂亮,{}',  # 'This city is beautiful, {}'
    '获取更多信息,{}',  # 'For more information, {}'
    '假设你是个人工智能小助手,{}',  # 'Suppose you are an AI assistant, {}'
    '这是个宝藏博主。{}',  # 'This is a gem of a blogger. {}'
    '北京是中国的首都,{}',  # 'Beijing is the capital of China, {}'
    '获得更多信息请点击相应的信息。{}',  # 'For more information, click the corresponding item. {}'
    '三峡大坝是个伟大的建筑。{}',  # 'The Three Gorges Dam is a great structure. {}'
]
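
These templates are presumably what the tuner passes as context_templates into compute_u and compute_v (the consuming module is not part of this excerpt). A small sketch of how one template expands into a rewriting prompt, with an invented request and the context_template list above in scope:

request = {'prompt': '{} plays the sport of', 'subject': 'LeBron James',
           'target': ' basketball'}

# compute_v builds one rewriting prompt per template:
rewriting_prompts = [c.format(request['prompt']) + request['target']
                     for c in context_template]
print(rewriting_prompts[3].format(request['subject']))
# -> 'User: LeBron James plays the sport of basketball'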
