# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from kmeng01/rome.
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from modelscope import AutoTokenizer

from swift.utils.logger import get_logger
from .nethook import TraceDict, set_requires_grad
from .repr_tools import (get_reprs_at_idxs, get_reprs_at_word_tokens,
                         get_words_idxs_in_templates)
from .rome_hparams import ROMEHyperParams

logger = get_logger()

def compute_v(model: torch.nn.Module,
              tokenizer: AutoTokenizer,
              request: Dict,
              hparams: ROMEHyperParams,
              layer: int,
              left_vector: torch.Tensor,
              context_templates: List[str],
              batch_first: bool = True) -> torch.Tensor:
    """
    Computes the value (right) vector for the rank-1 update.
    Runs a simple optimization procedure.
    """

    logger.info('Computing right vector (v)')

    # Compile the rewriting prompts and the KL prompts
    rewriting_prompts = [
        context.format(request['prompt']) + request['target']
        for context in context_templates
    ]
    kl_prompts = ['{} is a', '{}是一个']
    all_prompts = rewriting_prompts + kl_prompts
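    # The KL prompts (one English, one Chinese) do not teach the new fact;
    # they anchor a KL term below that keeps the model's predictions on
    # generic '{} is a' contexts close to their pre-edit values.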

    input_tok = tokenizer(
        [prompt.format(request['subject']) for prompt in all_prompts],
        return_tensors='pt',
        padding=True,
        return_token_type_ids=False,
    ).to(model.device)

    # Compute rewriting targets
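    # -100 is the conventional ignore index for token-level losses, so only
    # the target positions filled in below contribute to the loss. The
    # negative indexing assumes each row ends with the target tokens (i.e.
    # left-side padding), and the slice is shifted left by one because the
    # logits at position t predict the token at position t + 1.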
    rewriting_targets = torch.tensor(
        -100, device=model.device).repeat(
            len(rewriting_prompts), *input_tok['input_ids'].shape[1:])

    prompt = context_templates[0].format(request['prompt'])
    prompt_full = prompt + request['target']
    target_len = len(tokenizer.tokenize(prompt_full)) - len(
        tokenizer.tokenize(prompt))
    for i in range(len(rewriting_prompts)):
        rewriting_targets[i, -target_len - 1:-1] = input_tok['input_ids'][
            i, -target_len:].clone()

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        find_fact_lookup_idx(
            prompt,
            request['subject'],
            tokenizer,
            hparams.fact_token,
            verbose=(i == 0)) for i, prompt in enumerate(all_prompts)
    ]

    # Finalize rewrite and loss layers
    logger.info(f'Rewrite layer is {layer}')

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer, i.e. the hypothesized fact lookup location, will induce
    # the target token to be predicted at the final layer.
    hidden_size = model.config.n_embd if hasattr(
        model.config, 'n_embd') else model.config.hidden_size
    delta = torch.zeros((hidden_size, ),
                        requires_grad=True,
                        device=model.device)
    target_init, kl_distr_init = None, None

    # Inserts new "delta" variable at the appropriate part of the computation
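    # The traced MLP output is assumed to be a single batch-first tensor of
    # shape (batch, seq_len, hidden_size), so cur_out[i, idx] selects the
    # hidden state of one token in one prompt.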
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init

        # Store initial value of the vector of interest
        if target_init is None:
            logger.info('Recording initial value of v*')
            # Initial value is recorded for the clean sentence
            target_init = cur_out[0, lookup_idxs[0]].detach().clone()

        for i, idx in enumerate(lookup_idxs):
            cur_out[i, idx, :] += delta

        return cur_out

    # Optimizer
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    set_requires_grad(False, model)
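    # Only `delta` is optimized; every model weight stays frozen.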

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation
        with TraceDict(
                module=model,
                layers=[
                    hparams.mlp_module_tmp.format(layer),
                ],
                retain_input=False,
                retain_output=True,
                edit_output=edit_output_fn,
        ) as _:
            logits = model(**input_tok).logits

        # Compute distribution for KL divergence
        kl_logits = torch.stack(
            [
                logits[i - len(kl_prompts), idx, :]
                for i, idx in enumerate(lookup_idxs[-len(kl_prompts):])
            ],
            dim=0,
        )
        kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
        if kl_distr_init is None:
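            # delta is still zero on the first step, so this snapshot is the
            # unedited model's distribution over the KL prompts.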
            kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        log_probs = torch.log_softmax(logits, dim=2)

        loss = torch.gather(
            log_probs,
            2,
            torch.where(rewriting_targets != -100, rewriting_targets,
                        0).unsqueeze(2),
        ).squeeze(2)
        mask = (rewriting_targets != -100).float()

        # Aggregate total losses
        nll_loss_each = -(loss * mask).sum(1) / target_len
        nll_loss = nll_loss_each.mean()
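        # F.kl_div(input, target, log_target=True) computes KL(target‖input)
        # with both arguments in log space, i.e. KL(current‖initial) here:
        # drifting away from the pre-edit predictions is penalized.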
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init,
            kl_log_probs,
            log_target=True,
            reduction='batchmean')
        weight_decay = hparams.v_weight_decay * (
            torch.norm(delta) / torch.norm(target_init)**2)
        # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
        loss = nll_loss + kl_loss + weight_decay
        logger.info(
            f'loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + '
            f'{np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)}; '
            f"avg prob of [{request['target']}] "
            f'{torch.exp(-nll_loss_each).mean().item()}')
        if loss < 5e-2:
            break

        if it == hparams.v_num_grad_steps - 1:
            break

        # Backpropagate
        loss.backward()
        opt.step()

        # Project within L2 ball
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta

    # Retrieve cur_input, the current input to the 2nd MLP layer, and
    # cur_output, the original output of the 2nd MLP layer.
    cur_input, cur_output = get_module_input_output_at_word(
        model,
        tokenizer,
        layer,
        context_template=request['prompt'],
        word=request['subject'],
        module_template=hparams.rewrite_module_tmp,
        fact_token_strategy=hparams.fact_token,
        batch_first=batch_first)

    # Solving the linear system to compute the right vector
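    # The rank-1 edit adds outer(left_vector, right_vector) (transposed as
    # needed for the weight layout) to the rewritten projection, so the key
    # cur_input gains right_vector * dot(cur_input, left_vector); dividing
    # by that dot product makes the edited output exactly `target`.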
    right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
    logger.info(f'Delta norm: {(target - cur_output).norm().item()}')
    logger.info(
        f'Change in target norm: {target_init.norm().item()} to {target.norm().item()} => '
        f'{(target.norm() - target_init.norm()).item()}')
    logger.info(f'Division Factor: {torch.dot(cur_input, left_vector).item()}')
    logger.info(f'Right vector norm: {right_vector.norm()}')

    return right_vector


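# Illustrative only (comments, not executed): how `compute_v` typically pairs
# with the left vector to apply a ROME edit. `compute_u` and the weight
# layout are assumptions carried over from kmeng01/rome, not guarantees of
# this file.
#
#     left_vector = compute_u(model, tokenizer, request, hparams, layer,
#                             context_templates)
#     right_vector = compute_v(model, tokenizer, request, hparams, layer,
#                              left_vector, context_templates)
#     with torch.no_grad():
#         name = hparams.rewrite_module_tmp.format(layer) + '.weight'
#         w = model.get_parameter(name)
#         # transpose the outer product if the module stores (out, in)
#         w += torch.outer(left_vector, right_vector)

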
def get_module_input_output_at_word(
        model: torch.nn.Module,
        tok: Any,
        layer: int,
        context_template: str,
        word: str,
        module_template: str,
        fact_token_strategy: str,
        batch_first: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Retrieves detached representations for a word at the input and
    output of a particular layer module.
    """

    word_repr_args = dict(
        model=model,
        tokenizer=tok,
        layer=layer,
        module_template=module_template,
        batch_first=batch_first)
    if fact_token_strategy.startswith('subject_'):
        subtoken = fact_token_strategy[len('subject_'):]
        l_input, l_output = get_reprs_at_word_tokens(
            track='both',
            subtoken=subtoken,
            context_templates=[context_template],
            words=[word],
            **word_repr_args,
        )
    elif fact_token_strategy == 'last':
        l_input, l_output = get_reprs_at_idxs(
            track='both',
            contexts=[context_template.format(word)],
            idxs=[[-1]],
            **word_repr_args,
        )
    else:
        raise ValueError(f'fact_token={fact_token_strategy} not recognized')

    l_input, l_output = l_input[0], l_output[0]
    return l_input.detach(), l_output.detach()


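# For GPT-2-style models the module templates above typically look like
# 'transformer.h.{}.mlp' and 'transformer.h.{}.mlp.c_proj' (the rewritten
# projection); other architectures use different module paths, so treat
# these names as illustrative rather than required.

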
def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: Any,
    fact_token_strategy: str,
    verbose=True,
) -> int:
    """
    Computes the hypothesized fact lookup index given a sentence and subject.
    """

    if fact_token_strategy == 'last':
        ret = -1
    elif fact_token_strategy.startswith('subject_'):
        ret = get_words_idxs_in_templates(
            tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=fact_token_strategy[len('subject_'):],
        )[0][0]
    else:
        raise ValueError(f'fact_token={fact_token_strategy} not recognized')

    sentence = prompt.format(subject)
    if verbose:
        logger.info(
            f'Lookup index found: {ret} | Sentence: {sentence} | Token: '
            + tok.decode(tok(sentence)['input_ids'][ret]))

    return ret
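

# Example (token indices depend on the tokenizer, so treat the exact value
# as an assumption): for prompt 'The mother tongue of {} is' with subject
# 'Danielle Darrieux', fact_token_strategy='subject_last' returns the index
# of the last subject token ('Darrieux'), while 'last' simply returns -1.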