|
| 1 | +import functools |
| 2 | +from collections import defaultdict |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import LlamaModel |
| 7 | +from tqdm import tqdm |
| 8 | + |
| 9 | + |
| 10 | +def fake_quantize_tensor(w, n_bit=4): |
| 11 | + # symmetric quantization |
| 12 | + max_val = w.abs().amax(dim=1, keepdim=True) |
| 13 | + max_val = max_val.clamp(min=1e-5) |
| 14 | + max_int = 2 ** (n_bit - 1) - 1 |
| 15 | + min_int = -(2 ** (n_bit - 1) - 1) |
| 16 | + scales = max_val / max_int |
| 17 | + |
| 18 | + w = (torch.clamp(torch.round(w / scales), min_int, max_int)) * scales |
| 19 | + return w |
| 20 | + |
| 21 | + |
| 22 | +def compute_and_apply_scale(prev_op, layers, inp, block, **kwargs): |
| 23 | + x = inp |
| 24 | + # w: co, ci |
| 25 | + # x: n, ci |
| 26 | + with torch.no_grad(): |
| 27 | + org_out = block(x, **kwargs) |
| 28 | + if isinstance(org_out, tuple): |
| 29 | + org_out = org_out[0] |
| 30 | + |
| 31 | + x_max = x.abs().view(-1, x.shape[-1]).mean(0) |
| 32 | + |
| 33 | + best_error = float("inf") |
| 34 | + best_scales = None |
| 35 | + |
| 36 | + n_grid = 20 |
| 37 | + |
| 38 | + org_sd = {k: v.cpu() for k, v in block.state_dict().items()} |
| 39 | + for ratio in range(n_grid): |
| 40 | + ratio = ratio * 1 / n_grid |
| 41 | + scales = x_max.pow(ratio).clamp(min=1e-4).view(-1) |
| 42 | + scales = scales / (scales.max() * scales.min()).sqrt() |
| 43 | + for layer in layers: |
| 44 | + layer.weight.mul_(scales.view(1, -1)) |
| 45 | + layer.weight.data = fake_quantize_tensor(layer.weight.data).detach() / ( |
| 46 | + scales.view(1, -1) |
| 47 | + ) |
| 48 | + |
| 49 | + out = block(x, **kwargs) |
| 50 | + if isinstance(out, tuple): |
| 51 | + out = out[0] |
| 52 | + |
| 53 | + loss = (org_out - out).float().pow(2).mean().item() # float prevents overflow |
| 54 | + is_best = loss < best_error |
| 55 | + if is_best: |
| 56 | + best_error = loss |
| 57 | + best_scales = scales |
| 58 | + block.load_state_dict(org_sd) |
| 59 | + |
| 60 | + best_scales = best_scales.view(-1) |
| 61 | + |
| 62 | + # Apply scales to previous |
| 63 | + scales = best_scales.detach() |
| 64 | + if isinstance(prev_op, nn.Linear): |
| 65 | + prev_op.weight[-scales.size(0) :].div_(scales.view(-1, 1)) |
| 66 | + if prev_op.bias is not None: |
| 67 | + prev_op.bias.div_(scales.view(-1)) |
| 68 | + elif isinstance(prev_op, nn.RMSNorm): |
| 69 | + prev_op.weight.div_(scales) |
| 70 | + if hasattr(prev_op, "bias") and prev_op.bias is not None: |
| 71 | + prev_op.bias.div_(scales) |
| 72 | + # Apply scales to layers |
| 73 | + for layer in layers: |
| 74 | + layer.weight.mul_(scales.view(1, -1)) |
| 75 | + |
| 76 | + |
| 77 | +def apply_awq(model: LlamaModel, w_bit=4): |
| 78 | + """Implement AWQ from scratch... but in a way that is applicable to the Qualcomm LlamaModel definition.""" |
| 79 | + with torch.no_grad(): |
| 80 | + tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) |
| 81 | + # tokens = get_dataset() will add this later... |
| 82 | + hidden_states = model.tok_embeddings(tokens) |
| 83 | + |
| 84 | + for i in tqdm(range(len(model.layers)), desc="Running AWQ..."): |
| 85 | + """Solving AWQ layer by layer. In each decoder layer, we apply scales at four points: |
| 86 | + attention_norm -- [q, k, v] |
| 87 | + v -- o *not implemented yet |
| 88 | + ffn_norm -- [gate, up] |
| 89 | + up -- down |
| 90 | + """ |
| 91 | + decoder_layer = model.layers[i] |
| 92 | + named_linears = { |
| 93 | + name: m |
| 94 | + for name, m in decoder_layer.named_modules() |
| 95 | + if isinstance(m, nn.Linear) |
| 96 | + } |
| 97 | + |
| 98 | + def cache_input_hook(module, x, y, name, feat_dict): |
| 99 | + x = x[0] |
| 100 | + x = x.detach() |
| 101 | + feat_dict[name].append(x) |
| 102 | + |
| 103 | + input_feat = defaultdict(list) |
| 104 | + handles = [] |
| 105 | + for name in named_linears: |
| 106 | + handles.append( |
| 107 | + named_linears[name].register_forward_hook( |
| 108 | + functools.partial( |
| 109 | + cache_input_hook, name=name, feat_dict=input_feat |
| 110 | + ) |
| 111 | + ) |
| 112 | + ) |
| 113 | + # get output as next layer's input |
| 114 | + kwargs = { |
| 115 | + "freqs_cos": model.freqs_cos, |
| 116 | + "freqs_sin": model.freqs_sin, |
| 117 | + "atten_mask": atten_mask, |
| 118 | + "k_caches": None, |
| 119 | + "v_caches": None, |
| 120 | + } |
| 121 | + hidden_states, _, _ = decoder_layer(hidden_states, **kwargs) |
| 122 | + for h in handles: |
| 123 | + h.remove() |
| 124 | + input_feat = { |
| 125 | + k: torch.cat(v, dim=0) for k, v in input_feat.items() |
| 126 | + } # multi-input? |
| 127 | + |
| 128 | + compute_and_apply_scale( |
| 129 | + prev_op=decoder_layer.attention_norm, |
| 130 | + layers=[ |
| 131 | + decoder_layer.attention.wq, |
| 132 | + decoder_layer.attention.wk, |
| 133 | + decoder_layer.attention.wv, |
| 134 | + ], |
| 135 | + inp=input_feat["attention.wq"], |
| 136 | + block=decoder_layer.attention, |
| 137 | + **kwargs |
| 138 | + ) |
| 139 | + # Need to add v--o. Technical difficulty due to GQA. But apparently this has the least impact anyway |
| 140 | + compute_and_apply_scale( |
| 141 | + prev_op=decoder_layer.ffn_norm, |
| 142 | + layers=[decoder_layer.feed_forward.w1, decoder_layer.feed_forward.w3], |
| 143 | + inp=input_feat["feed_forward.w1"], |
| 144 | + block=decoder_layer.feed_forward, |
| 145 | + ) |
| 146 | + compute_and_apply_scale( |
| 147 | + prev_op=decoder_layer.feed_forward.w3, |
| 148 | + layers=[decoder_layer.feed_forward.w2], |
| 149 | + inp=input_feat["feed_forward.w2"], |
| 150 | + block=decoder_layer.feed_forward.w2, |
| 151 | + ) |
0 commit comments