
Commit eb1789b

Updating eval for lm_eval 0.4 and 0.3
Summary: lm_eval 0.4 broke backward compatibility; this change makes eval work regardless of which version is installed (and still run when lm_eval is absent).

Test Plan:
(on both versions and without lm_eval installed)
python quantize.py --mode int8

(on both versions)
python eval.py --tasks wikitext

wikitext: {'word_perplexity,none': 12.212490471702079, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.59675331009031, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6751414412399839, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

For model checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth
wikitext: {'word_perplexity': 12.212490471702079, 'byte_perplexity': 1.59675331009031, 'bits_per_byte': 0.6751414412399839}

ghstack-source-id: f538368
Pull Request resolved: #91
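Note that the two harness versions also report results under different key names, which is why the two wikitext outputs above do not look identical: 0.4 uses keys like 'word_perplexity,none' while 0.3 uses plain 'word_perplexity'. A minimal sketch (a hypothetical helper, not part of this commit) for reading a metric from either format:

    # Hypothetical helper: read a metric from lm_eval 0.3-style or 0.4-style results.
    def get_metric(task_results: dict, name: str):
        # lm_eval 0.4 reports "<metric>,<filter>" keys; 0.3 reports plain metric names.
        for key in (f"{name},none", name):
            if key in task_results:
                return task_results[key]
        raise KeyError(name)

    results = {'word_perplexity,none': 12.212490471702079, 'byte_perplexity,none': 1.59675331009031}
    print(get_metric(results, 'word_perplexity'))  # 12.212490471702079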
Parent: 8608fd3 · Commit: eb1789b

File tree: 3 files changed, +148 −148 lines

GPTQ.py

Lines changed: 121 additions & 128 deletions
@@ -3,158 +3,151 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-import sys
 
 import torch
 
-lm_evaluation_harness_path = "/".join(
-    os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"]
-)
-sys.path.insert(0, lm_evaluation_harness_path)
-import main as lm_evaluation_harness_main
 import torch.fx as fx
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
-from eval import setup_cache_padded_seq_input_pos_max_seq_length_for_prefill
-from generate import encode_tokens
-
 aten = torch.ops.aten
 
-try:
-    import lm_eval
-    class InputRecorder(lm_eval.base.BaseLM):
-        """
-        This is a fake evaluation wrapper that just records the inputs
-        so that they can be used in calibration.
-
-        If pad_calibration_inputs is enabled, the input recorder will take
-        each input and pad/truncate it down to the calibration_seq_length.
-        It will also edit the model embeddings to be zero for the 0 token used
-        in padding and avoid any inputs with the 0 token.
-
-        If not, it will only truncate inputs to the desired length.
-        """
-
-        def __init__(
-            self,
-            model,
-            tokenizer,
-            calibration_seq_length,
-            pad_calibration_inputs=False,
-        ):
-            super().__init__()
-            self._model = model
-            self._tokenizer = tokenizer
-            self._device = torch.device("cpu")
-            self.vocab_size = model.config.vocab_size
-            self.calibration_seq_length = calibration_seq_length
-            self.pad_calibration_inputs = pad_calibration_inputs
-            self.inputs = None
-
-            if self.pad_calibration_inputs:
-                # This is needed for the pad_calibration_inputs option
-                # to work properly, the 0 token's embeddings are set to 0 so that
-                # the padded inputs will not affect the model numerics. This token isn't used
-                # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs
-                # where it appears
-                try:
-                    if isinstance(self._model.transformer.wte, nn.Embedding):
-                        self.mod.transformer.wte.weight.data[0, :] *= 0
-                except:
-                    print(
-                        "Did not find embeddings in model.transformer.wte, disabling padding"
-                    )
-                    self.pad_calibration_inputs = False
+from eval import (
+    setup_cache_padded_seq_input_pos_max_seq_length_for_prefill,
+    encode_tokens,
+    eval_wrapper
+)
 
-        @property
-        def eot_token_id(self):
-            return self._tokenizer.eos_id()
 
-        @property
-        def max_length(self):
-            return self.calibration_seq_length
+class InputRecorder(eval_wrapper):
+    """
+    This is a fake evaluation wrapper that just records the inputs
+    so that they can be used in calibration.
 
-        @property
-        def max_gen_toks(self):
-            return 50
+    If pad_calibration_inputs is enabled, the input recorder will take
+    each input and pad/truncate it down to the calibration_seq_length.
+    It will also edit the model embeddings to be zero for the 0 token used
+    in padding and avoid any inputs with the 0 token.
 
-        @property
-        def batch_size(self):
-            return 1
+    If not, it will only truncate inputs to the desired length.
+    """
 
-        @property
-        def device(self):
-            return self._device
+    def __init__(
+        self,
+        model,
+        tokenizer,
+        calibration_seq_length,
+        pad_calibration_inputs=False,
+    ):
+        super().__init__()
+        self._model = model
+        self._tokenizer = tokenizer
+        self._device = torch.device("cpu")
+        self.vocab_size = model.config.vocab_size
+        self.calibration_seq_length = calibration_seq_length
+        self.pad_calibration_inputs = pad_calibration_inputs
+        self.inputs = None
+
+        if self.pad_calibration_inputs:
+            # This is needed for the pad_calibration_inputs option
+            # to work properly, the 0 token's embeddings are set to 0 so that
+            # the padded inputs will not affect the model numerics. This token isn't used
+            # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs
+            # where it appears
+            try:
+                if isinstance(self._model.transformer.wte, nn.Embedding):
+                    self.mod.transformer.wte.weight.data[0, :] *= 0
+            except:
+                print(
+                    "Did not find embeddings in model.transformer.wte, disabling padding"
+                )
+                self.pad_calibration_inputs = False
 
-        def tok_encode(self, string: str):
-            encoded = encode_tokens(
-                self._tokenizer, string, bos=True, device=self._device
-            )
-            # encoded is a pytorch tensor, but some internal logic in the
-            # eval harness expects it to be a list instead
-            # TODO: verify this for multi-batch as well
-            encoded = encoded.tolist()
-            return encoded
-
-        def tok_decode(self, tokens):
-            decoded = self._tokenizer.decode(tokens)
-            return decoded
-
-        def add_input(self, args):
-            if self.inputs is None:
-                self.inputs = [MultiInput([arg]) for arg in args]
-            else:
-                self.inputs = [
-                    multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
-                ]
+    @property
+    def eot_token_id(self):
+        return self._tokenizer.eos_id()
 
-        def get_recorded_inputs(self):
-            return self.inputs
+    @property
+    def max_length(self):
+        return self.calibration_seq_length
 
-        def _model_call(self, inps):
-            inps = inps.squeeze(0)
-            T = len(inps)
-            if (
-                # can't use inputs that are too short when padding disabled
-                (T < self.calibration_seq_length and not self.pad_calibration_inputs)
-                or
-                # can't use inputs that actually use token we use for padding
-                (self.pad_calibration_inputs and 0 in inps)
-            ):
-                # give random output
-                return torch.randn(
-                    (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
-                )
+    @property
+    def max_gen_toks(self):
+        return 50
 
-            # pad or truncate to the right size
-            if T >= self.calibration_seq_length:
-                inps = inps[: self.calibration_seq_length]
-            else:
-                inps = F.pad(inps, (0, self.calibration_seq_length - T))
-
-            max_new_tokens = 1
-            (
-                seq,
-                input_pos,
-                max_seq_length,
-            ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
-                self._model, inps, max_new_tokens, self.max_length
-            )
-            x = seq.index_select(0, input_pos).view(1, -1)
-            self.add_input((x, input_pos))
+    @property
+    def batch_size(self):
+        return 1
 
-            # output `something` with correct shape to keep eval going
+    @property
+    def device(self):
+        return self._device
+
+    def tok_encode(self, string: str):
+        encoded = encode_tokens(
+            self._tokenizer, string, bos=True, device=self._device
+        )
+        # encoded is a pytorch tensor, but some internal logic in the
+        # eval harness expects it to be a list instead
+        # TODO: verify this for multi-batch as well
+        encoded = encoded.tolist()
+        return encoded
+
+    def tok_decode(self, tokens):
+        decoded = self._tokenizer.decode(tokens)
+        return decoded
+
+    def add_input(self, args):
+        if self.inputs is None:
+            self.inputs = [MultiInput([arg]) for arg in args]
+        else:
+            self.inputs = [
+                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+            ]
+
+    def get_recorded_inputs(self):
+        return self.inputs
+
+    def _model_call(self, inps):
+        inps = inps.squeeze(0)
+        T = len(inps)
+        if (
+            # can't use inputs that are too short when padding disabled
+            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
+            or
+            # can't use inputs that actually use token we use for padding
+            (self.pad_calibration_inputs and 0 in inps)
+        ):
+            # give random output
             return torch.randn(
                 (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
             )
 
-        def _model_generate(self, context, max_length, eos_token_id):
-            raise Exception("unimplemented")
-except ImportError:
-    pass
+        # pad or truncate to the right size
+        if T >= self.calibration_seq_length:
+            inps = inps[: self.calibration_seq_length]
+        else:
+            inps = F.pad(inps, (0, self.calibration_seq_length - T))
+
+        max_new_tokens = 1
+        (
+            seq,
+            input_pos,
+            max_seq_length,
+        ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
+            self._model, inps, max_new_tokens, self.max_length
+        )
+        x = seq.index_select(0, input_pos).view(1, -1)
+        self.add_input((x, input_pos))
+
+        # output `something` with correct shape to keep eval going
+        return torch.randn(
+            (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
+        )
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
 
 
 class MultiInput:
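For context, the refactored InputRecorder is not called directly; the eval harness drives it, and the recorder just captures the (x, input_pos) pairs inside _model_call while returning random logits. A rough usage sketch in the spirit of quantize.py's get_inputs (the model and tokenizer objects are assumed to be already loaded; the task list and limit are illustrative):

    # Sketch: collect GPTQ calibration inputs by running the harness through the recorder.
    from GPTQ import InputRecorder
    from eval import get_task_dict, evaluate

    input_recorder = InputRecorder(
        model,                        # assumed: a loaded gpt-fast Transformer
        tokenizer,                    # assumed: a SentencePiece tokenizer
        calibration_seq_length=100,
        pad_calibration_inputs=False,
    )
    task_dict = get_task_dict(["wikitext"])
    # The recorder returns random logits, so this pass only records inputs.
    evaluate(input_recorder, task_dict, limit=10)
    calibration_inputs = input_recorder.get_recorded_inputs()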

eval.py

Lines changed: 23 additions & 17 deletions
@@ -18,26 +18,32 @@
 torch._inductor.config.triton.cudagraphs = True
 torch._dynamo.config.cache_size_limit = 100000
 
-# support running without installing as a package
-wd = Path(__file__).parent.parent.resolve()
-sys.path.append(str(wd))
-
-# hacky path setup for lm-evaluation-harness
-import os
-import sys
-
 from sentencepiece import SentencePieceProcessor
 
 from model import Transformer
 
-lm_evaluation_harness_path = '/'.join(
-    os.getcwd().split('/')[:-1] + ['lm-evaluation-harness'])
-sys.path.insert(0, lm_evaluation_harness_path)
-import lm_eval
-import main as lm_evaluation_harness_main
+try:
+    import lm_eval
+    lm_eval_available = True
+except:
+    lm_eval_available = False
 
 from generate import _load_model, encode_tokens, model_forward
 
+if lm_eval_available:
+    try: # lm_eval version 0.4
+        from lm_eval.models.huggingface import HFLM as eval_wrapper
+        from lm_eval.tasks import get_task_dict
+        from lm_eval.evaluator import evaluate
+        lm_eval.tasks.initialize_tasks()
+    except: #lm_eval version 0.3
+        from lm_eval import base
+        from lm_eval import tasks
+        from lm_eval import evaluator
+        eval_wrapper=base.BaseLM
+        get_task_dict=tasks.get_task_dict
+        evaluate=evaluator.evaluate
+
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     model: Transformer,
@@ -77,7 +83,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
 
     return seq, input_pos, max_seq_length
 
-class GPTFastEvalWrapper(lm_eval.base.BaseLM):
+class GPTFastEvalWrapper(eval_wrapper):
     """
     A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
@@ -113,7 +119,7 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    def tok_encode(self, string: str):
+    def tok_encode(self, string: str, **kwargs):
         encoded = encode_tokens(self._tokenizer,
             string, bos=True, device=self._device)
         # encoded is a pytorch tensor, but some internal logic in the
@@ -176,9 +182,9 @@ def eval(
     if 'hendrycks_test' in tasks:
         tasks.remove('hendrycks_test')
         tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()]
-    task_dict = lm_eval.tasks.get_task_dict(tasks)
+    task_dict = get_task_dict(tasks)
 
-    eval_results = lm_eval.evaluator.evaluate(
+    eval_results = evaluate(
         model_eval_wrapper,
         task_dict,
         limit=limit,
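The **kwargs added to tok_encode is part of what keeps GPTFastEvalWrapper working on both harnesses: lm_eval 0.4 code paths can call tok_encode with extra keyword arguments that 0.3 never passed, so accepting and ignoring them avoids a TypeError. A minimal sketch of the pattern (the specific keyword names are assumptions, not an exhaustive list):

    # Sketch: one signature that tolerates harness-specific keyword arguments.
    def tok_encode(self, string: str, **kwargs):
        # e.g. lm_eval 0.4 may pass options such as add_special_tokens (name assumed);
        # they are ignored here because encode_tokens already handles the BOS token.
        encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device)
        return encoded.tolist()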

quantize.py

Lines changed: 4 additions & 3 deletions
@@ -12,7 +12,8 @@
 from sentencepiece import SentencePieceProcessor
 
 try:
-    from GPTQ import GenericGPTQRunner, InputRecorder, lm_eval
+    from GPTQ import GenericGPTQRunner, InputRecorder
+    from eval import get_task_dict, evaluate
 except:
     pass
 
@@ -248,9 +249,9 @@ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibrati
         calibration_seq_length,
         pad_calibration_inputs,
     )
-    task_dict = lm_eval.tasks.get_task_dict(calibration_tasks)
+    task_dict = get_task_dict(calibration_tasks)
     print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
-    lm_eval.evaluator.evaluate(
+    evaluate(
         input_recorder,
         task_dict,
         limit=calibration_limit,
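Because the new imports of get_task_dict and evaluate sit inside the existing try/except, a missing lm_eval simply leaves those names undefined; plain int8 quantization still runs (as in the test plan above) and only the GPTQ calibration path needs them. A small guard one could add before calibration (an assumption, not part of this commit):

    # Hypothetical guard: fail fast with a clear message before GPTQ calibration.
    def require_lm_eval():
        if "get_task_dict" not in globals() or "evaluate" not in globals():
            raise RuntimeError("GPTQ calibration requires lm_eval (0.3 or 0.4) to be installed")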
