Skip to content

Commit 6df1d1f

Browse files
committed
Remove unnecessary wrapper code
Summary: this is inheriting from another wrapper that implements the same stuff. Test Plan: python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5. Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: e920a47. Pull Request resolved: #103
1 parent 9293781 commit 6df1d1f

File tree

1 file changed

+3
-39
lines changed

1 file changed

+3
-39
lines changed

GPTQ.py

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515

1616
from eval import (
1717
setup_cache_padded_seq_input_pos_max_seq_length_for_prefill,
18-
encode_tokens,
19-
eval_wrapper
18+
GPTFastEvalWrapper
2019
)
2120

2221

23-
class InputRecorder(eval_wrapper):
22+
class InputRecorder(GPTFastEvalWrapper):
2423
"""
2524
This is a fake evaluation wrapper that just records the inputs
2625
so that they can be used in calibration.
@@ -40,7 +39,7 @@ def __init__(
4039
calibration_seq_length,
4140
pad_calibration_inputs=False,
4241
):
43-
super().__init__()
42+
super().__init__(model, tokenizer, calibration_seq_length)
4443
self._model = model
4544
self._tokenizer = tokenizer
4645
self._device = torch.device("cpu")
@@ -64,39 +63,6 @@ def __init__(
6463
)
6564
self.pad_calibration_inputs = False
6665

67-
@property
68-
def eot_token_id(self):
69-
return self._tokenizer.eos_id()
70-
71-
@property
72-
def max_length(self):
73-
return self.calibration_seq_length
74-
75-
@property
76-
def max_gen_toks(self):
77-
return 50
78-
79-
@property
80-
def batch_size(self):
81-
return 1
82-
83-
@property
84-
def device(self):
85-
return self._device
86-
87-
def tok_encode(self, string: str):
88-
encoded = encode_tokens(
89-
self._tokenizer, string, bos=True, device=self._device
90-
)
91-
# encoded is a pytorch tensor, but some internal logic in the
92-
# eval harness expects it to be a list instead
93-
# TODO: verify this for multi-batch as well
94-
encoded = encoded.tolist()
95-
return encoded
96-
97-
def tok_decode(self, tokens):
98-
decoded = self._tokenizer.decode(tokens)
99-
return decoded
10066

10167
def add_input(self, args):
10268
if self.inputs is None:
@@ -146,8 +112,6 @@ def _model_call(self, inps):
146112
(1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
147113
)
148114

149-
def _model_generate(self, context, max_length, eos_token_id):
150-
raise Exception("unimplemented")
151115

152116

153117
class MultiInput:

0 commit comments

Comments (0)