Commit f44ef4e

Fixes for eval and GPTQ after move to gpt-fast
Summary: The move from simple_gpt to gpt-fast altered some things; this unbreaks eval and GPTQ. Note that GPTQ is still broken due to a kv cache issue in the model, which needs either non-public PyTorch functionality or a change to the GPTQ implementation; see the next PR in the stack for a fix.

Test Plan:
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7e73383
Pull Request resolved: #93
1 parent ce8c6be commit f44ef4e
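
For reference, a minimal Python sketch of the flow the eval.py command in the Test Plan exercises end to end. Only Transformer, SentencePieceProcessor, and eval() come from the diffs below; the loading helpers, paths, and from_name usage are assumptions for illustration, not part of this commit.

import torch
from sentencepiece import SentencePieceProcessor

from eval import eval            # the @torch.no_grad() entry point patched below
from model import Transformer    # gpt-fast model class (replaces the old LLaMA import)

# Hypothetical paths following the Test Plan's checkpoints/$MODEL_REPO layout.
checkpoint_path = "checkpoints/<MODEL_REPO>/model.pth"
tokenizer_path = "checkpoints/<MODEL_REPO>/tokenizer.model"

# Assumed loading flow: build the architecture from the repo name, then load the
# converted state dict saved next to the tokenizer.
model = Transformer.from_name("<MODEL_REPO>")
model.load_state_dict(torch.load(checkpoint_path, mmap=True), assign=True)
model.eval()
tokenizer = SentencePieceProcessor(model_file=tokenizer_path)

# Run the lm-evaluation-harness tasks and print the results dictionary.
results = eval(model, tokenizer, tasks=["hellaswag"], limit=None)
print(results)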

File tree

2 files changed: +10 -10 lines changed
GPTQ.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ def device(self):
 
     def tok_encode(self, string: str):
         encoded = encode_tokens(
-            self._tokenizer, string, bos=True, eos=False, device=self._device
+            self._tokenizer, string, bos=True, device=self._device
         )
         # encoded is a pytorch tensor, but some internal logic in the
         # eval harness expects it to be a list instead
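
The dropped eos=False argument tracks a signature change in the encode_tokens helper: gpt-fast's version only supports prepending a BOS token and has no eos parameter. A minimal sketch of that helper, paraphrased from gpt-fast's generate.py (the exact body here is an assumption, not part of this diff):

import torch

def encode_tokens(tokenizer, string, bos=True, device="cuda"):
    # Tokenize with SentencePiece, optionally prepending the BOS id.
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    # Note: there is no eos parameter, hence the call-site change above.
    return torch.tensor(tokens, dtype=torch.int, device=device)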

eval.py

Lines changed: 9 additions & 9 deletions

@@ -28,7 +28,7 @@
 
 from sentencepiece import SentencePieceProcessor
 
-from model import LLaMA
+from model import Transformer
 
 lm_evaluation_harness_path = '/'.join(
     os.getcwd().split('/')[:-1] + ['lm-evaluation-harness'])
@@ -40,7 +40,7 @@
 
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
-    model: LLaMA,
+    model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
     max_seq_length: Optional[int] = None,
@@ -77,13 +77,13 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
 
     return seq, input_pos, max_seq_length
 
-class SimpleGPTEvalWrapper(lm_eval.base.BaseLM):
+class GPTFastEvalWrapper(lm_eval.base.BaseLM):
     """
-    A wrapper class for SimpleGPT, providing integration with the lm-evaluation-harness library.
+    A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
     def __init__(
         self,
-        model: LLaMA,
+        model: Transformer,
         tokenizer,
         max_seq_length: Optional[int]=None,
     ):
@@ -115,7 +115,7 @@ def device(self):
 
     def tok_encode(self, string: str):
         encoded = encode_tokens(self._tokenizer,
-            string, bos=True, eos=False, device=self._device)
+            string, bos=True, device=self._device)
         # encoded is a pytorch tensor, but some internal logic in the
         # eval harness expects it to be a list instead
         # TODO: verify this for multi-batch as well
@@ -148,7 +148,7 @@ def _model_generate(self, context, max_length, eos_token_id):
 
 @torch.no_grad()
 def eval(
-    model: LLaMA,
+    model: Transformer,
     tokenizer,
     tasks: list = ["hellaswag"],
     limit: Optional[int] = None,
@@ -158,7 +158,7 @@ def eval(
     Evaluates a language model on a specified task using the lm-evaluation-harness library.
 
     Args:
-        model (LLaMA): The pre-trained language model to evaluate.
+        model (Transformer): The pre-trained language model to evaluate.
         tokenizer: The tokenizer to use for encoding/decoding text.
         task (str): The name of the evaluation task to perform.
         limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
@@ -167,7 +167,7 @@ def eval(
     Returns:
         eval_results (dict): A dictionary of evaluation results for the specified task(s).
     """
-    model_eval_wrapper = SimpleGPTEvalWrapper(
+    model_eval_wrapper = GPTFastEvalWrapper(
         model,
         tokenizer,
         max_seq_length,
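
For context, a hedged sketch of how the renamed GPTFastEvalWrapper is typically driven once eval() constructs it. The get_task_dict and evaluate calls below follow the pre-0.4 lm-evaluation-harness API (the same releases that still expose lm_eval.base.BaseLM, as seen in the diff); the exact wiring inside eval() may differ.

from lm_eval import evaluator, tasks as lm_tasks

from eval import GPTFastEvalWrapper  # the wrapper class renamed in this diff

def run_eval_sketch(model, tokenizer, task_names=("hellaswag",), limit=None, max_seq_length=None):
    # Wrap the gpt-fast Transformer so the harness can treat it like any other LM.
    wrapper = GPTFastEvalWrapper(model, tokenizer, max_seq_length)
    # Resolve task names (e.g. "hellaswag", "wikitext") into harness task objects.
    task_dict = lm_tasks.get_task_dict(list(task_names))
    # Run the harness evaluation loop and return the per-task results dictionary.
    return evaluator.evaluate(lm=wrapper, task_dict=task_dict, limit=limit)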
