|
3 | 3 |
|
4 | 4 | # This source code is licensed under the license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 |
| -import os |
7 |
| -import sys |
8 | 6 |
|
9 | 7 | import torch
|
10 | 8 |
|
11 |
| -lm_evaluation_harness_path = "/".join( |
12 |
| - os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"] |
13 |
| -) |
14 |
| -sys.path.insert(0, lm_evaluation_harness_path) |
15 |
| -import main as lm_evaluation_harness_main |
16 | 9 | import torch.fx as fx
|
17 | 10 | import torch.nn as nn
|
18 | 11 | import torch.nn.functional as F
|
19 | 12 | from torch.utils._pytree import tree_flatten, tree_unflatten
|
20 | 13 |
|
21 |
| -from eval import setup_cache_padded_seq_input_pos_max_seq_length_for_prefill |
22 |
| -from generate import encode_tokens |
23 |
| - |
24 | 14 | aten = torch.ops.aten
|
25 | 15 |
|
26 |
| -try: |
27 |
| - import lm_eval |
28 |
| - class InputRecorder(lm_eval.base.BaseLM): |
29 |
| - """ |
30 |
| - This is a fake evaluation wrapper that just records the inputs |
31 |
| - so that they can be used in calibration. |
32 |
| -
|
33 |
| - If pad_calibration_inputs is enabled, the input recorder will take |
34 |
| - each input and pad/truncate it down to the calibration_seq_length. |
35 |
| - It will also edit the model embeddings to be zero for the 0 token used |
36 |
| - in padding and avoid any inputs with the 0 token. |
37 |
| -
|
38 |
| - If not, it will only truncate inputs to the desired length. |
39 |
| - """ |
40 |
| - |
41 |
| - def __init__( |
42 |
| - self, |
43 |
| - model, |
44 |
| - tokenizer, |
45 |
| - calibration_seq_length, |
46 |
| - pad_calibration_inputs=False, |
47 |
| - ): |
48 |
| - super().__init__() |
49 |
| - self._model = model |
50 |
| - self._tokenizer = tokenizer |
51 |
| - self._device = torch.device("cpu") |
52 |
| - self.vocab_size = model.config.vocab_size |
53 |
| - self.calibration_seq_length = calibration_seq_length |
54 |
| - self.pad_calibration_inputs = pad_calibration_inputs |
55 |
| - self.inputs = None |
56 |
| - |
57 |
| - if self.pad_calibration_inputs: |
58 |
| - # This is needed for the pad_calibration_inputs option |
59 |
| - # to work properly, the 0 token's embeddings are set to 0 so that |
60 |
| - # the padded inputs will not affect the model numerics. This token isn't used |
61 |
| - # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs |
62 |
| - # where it appears |
63 |
| - try: |
64 |
| - if isinstance(self._model.transformer.wte, nn.Embedding): |
65 |
| - self.mod.transformer.wte.weight.data[0, :] *= 0 |
66 |
| - except: |
67 |
| - print( |
68 |
| - "Did not find embeddings in model.transformer.wte, disabling padding" |
69 |
| - ) |
70 |
| - self.pad_calibration_inputs = False |
| 16 | +from eval import ( |
| 17 | + setup_cache_padded_seq_input_pos_max_seq_length_for_prefill, |
| 18 | + encode_tokens, |
| 19 | + eval_wrapper |
| 20 | +) |
71 | 21 |
|
72 |
| - @property |
73 |
| - def eot_token_id(self): |
74 |
| - return self._tokenizer.eos_id() |
75 | 22 |
|
76 |
| - @property |
77 |
| - def max_length(self): |
78 |
| - return self.calibration_seq_length |
class InputRecorder(eval_wrapper):
    """
    This is a fake evaluation wrapper that just records the inputs
    so that they can be used in calibration.

    If pad_calibration_inputs is enabled, the input recorder will take
    each input and pad/truncate it down to the calibration_seq_length.
    It will also edit the model embeddings to be zero for the 0 token used
    in padding and avoid any inputs with the 0 token.

    If not, it will only truncate inputs to the desired length.
    """

    def __init__(
        self,
        model,
        tokenizer,
        calibration_seq_length,
        pad_calibration_inputs=False,
    ):
        super().__init__()
        self._model = model
        self._tokenizer = tokenizer
        # Calibration recording always runs on CPU.
        self._device = torch.device("cpu")
        self.vocab_size = model.config.vocab_size
        self.calibration_seq_length = calibration_seq_length
        self.pad_calibration_inputs = pad_calibration_inputs
        # Lazily initialized in add_input(): one MultiInput per model arg.
        self.inputs = None

        if self.pad_calibration_inputs:
            # This is needed for the pad_calibration_inputs option
            # to work properly, the 0 token's embeddings are set to 0 so that
            # the padded inputs will not affect the model numerics. This token isn't used
            # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs
            # where it appears
            try:
                if isinstance(self._model.transformer.wte, nn.Embedding):
                    # FIX: was `self.mod.transformer...` — no such attribute
                    # exists (only `self._model`), so the zeroing never ran and
                    # the blanket except silently disabled padding every time.
                    self._model.transformer.wte.weight.data[0, :] *= 0
            except AttributeError:
                # FIX: narrowed from a bare `except:` — we only expect the
                # attribute chain to be missing on models without .transformer.wte.
                print(
                    "Did not find embeddings in model.transformer.wte, disabling padding"
                )
                self.pad_calibration_inputs = False

    @property
    def eot_token_id(self):
        # End-of-text token id, delegated to the wrapped tokenizer.
        return self._tokenizer.eos_id()

    @property
    def max_length(self):
        # The harness truncates/windows inputs to this length.
        return self.calibration_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str):
        encoded = encode_tokens(
            self._tokenizer, string, bos=True, device=self._device
        )
        # encoded is a pytorch tensor, but some internal logic in the
        # eval harness expects it to be a list instead
        # TODO: verify this for multi-batch as well
        encoded = encoded.tolist()
        return encoded

    def tok_decode(self, tokens):
        decoded = self._tokenizer.decode(tokens)
        return decoded

    def add_input(self, args):
        """Record one tuple of model args, one MultiInput bucket per arg."""
        if self.inputs is None:
            self.inputs = [MultiInput([arg]) for arg in args]
        else:
            self.inputs = [
                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
            ]

    def get_recorded_inputs(self):
        """Return the recorded calibration inputs (None if nothing recorded)."""
        return self.inputs

    def _model_call(self, inps):
        """Record the (padded/truncated) input; return dummy logits so the
        eval harness keeps iterating. Never runs the real model."""
        inps = inps.squeeze(0)
        T = len(inps)
        if (
            # can't use inputs that are too short when padding disabled
            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
            or
            # can't use inputs that actually use token we use for padding
            (self.pad_calibration_inputs and 0 in inps)
        ):
            # give random output
            return torch.randn(
                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
            )

        # pad or truncate to the right size
        if T >= self.calibration_seq_length:
            inps = inps[: self.calibration_seq_length]
        else:
            inps = F.pad(inps, (0, self.calibration_seq_length - T))

        max_new_tokens = 1
        (
            seq,
            input_pos,
            max_seq_length,
        ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
            self._model, inps, max_new_tokens, self.max_length
        )
        x = seq.index_select(0, input_pos).view(1, -1)
        self.add_input((x, input_pos))

        # output `something` with correct shape to keep eval going
        return torch.randn(
            (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
        )

    def _model_generate(self, context, max_length, eos_token_id):
        # FIX: NotImplementedError (a subclass of Exception, so existing
        # handlers still catch it) instead of a generic Exception.
        raise NotImplementedError("unimplemented")
158 | 151 |
|
159 | 152 |
|
160 | 153 | class MultiInput:
|
|
0 commit comments