
Commit 33ce0cf

add support to evaluate the model with attention sink
Differential Revision: [D66378384](https://our.internmc.facebook.com/intern/diff/D66378384/) [ghstack-poisoned]
1 parent 64d3437

1 file changed: +63 -0 lines changed

examples/models/llama/eval_llama_lib.py

Lines changed: 63 additions & 0 deletions
@@ -155,6 +155,44 @@ def _model_call(self, inps):
         pass
 
 
+class AttentionSinkEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for evaluating the model with attention sink.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        sink_size: int,
+        window_size: int,
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        super().__init__(model, tokenizer, max_seq_length, use_kv_cache)
+        self.cache_size = sink_size + window_size
+        assert self._use_kv_cache, "Attention sink only works with kv cache."
+
+    def _model_call(self, inps):
+        # Given inps (tokens), return the logits.
+
+        # Example:
+        # inps: Tensor of shape (1, N)
+        # logits: Tensor of shape (1, N, vocab_size)
+        _, seq_len = inps.shape
+        result = self._model(
+            inps[:, : min(seq_len, self.cache_size)],
+            torch.tensor([0], dtype=torch.int64, device=self.device),
+        )
+        for pos in range(min(seq_len, self.cache_size), seq_len):
+            logits = self._model(
+                inps[:, pos : pos + 1],
+                torch.tensor([pos], dtype=torch.int64, device=self.device),
+            )
+            result = torch.cat([result, logits[:, -1:, :]], dim=1)
+        return result
+
+
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
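
For context on the hunk above: _model_call prefills at most cache_size = sink_size + window_size tokens in a single forward call, then feeds the remaining tokens one position at a time and appends each step's last-position logits along the sequence dimension, so the result keeps the (1, N, vocab_size) shape described in the comment. The sketch below mirrors that accumulation loop with a stand-in model; DummyModel and collect_logits are illustrative names and are not part of this commit.

import torch


class DummyModel(torch.nn.Module):
    # Stand-in for the eager Llama module: maps (tokens of shape (1, T), input_pos)
    # to logits of shape (1, T, vocab_size). Purely illustrative.
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        return torch.randn(tokens.shape[0], tokens.shape[1], self.vocab_size)


def collect_logits(model: torch.nn.Module, inps: torch.Tensor, cache_size: int) -> torch.Tensor:
    # Same accumulation pattern as AttentionSinkEvalWrapper._model_call.
    _, seq_len = inps.shape
    # Prefill: everything that fits into the sink + window cache, in one call.
    result = model(
        inps[:, : min(seq_len, cache_size)],
        torch.tensor([0], dtype=torch.int64),
    )
    # Decode: remaining tokens one at a time, appending the last-position logits.
    for pos in range(min(seq_len, cache_size), seq_len):
        logits = model(
            inps[:, pos : pos + 1],
            torch.tensor([pos], dtype=torch.int64),
        )
        result = torch.cat([result, logits[:, -1:, :]], dim=1)
    return result


tokens = torch.randint(0, 32, (1, 10))
out = collect_logits(DummyModel(), tokens, cache_size=6)
assert out.shape == (1, 10, 32)  # one logits row per input token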
@@ -225,6 +263,25 @@ def gen_eval_wrapper(
     if args.output_eager_checkpoint_file is not None:  # pyre-ignore
         torch.save(model, args.output_eager_checkpoint_file)
 
+    if (use_attention_sink := args.use_attention_sink) is not None and (
+        attention_sink_eval_length := args.attention_sink_eval_length
+    ) is not None:  # pyre-ignore
+        attention_sink_params = use_attention_sink.split(",")
+        assert len(attention_sink_params) == 3
+        sink_size = int(attention_sink_params[0])
+        window_size = int(attention_sink_params[1])
+
+        assert args.max_seq_length == sink_size + window_size
+
+        return AttentionSinkEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            sink_size=sink_size,
+            window_size=window_size,
+            max_seq_length=attention_sink_eval_length,
+            use_kv_cache=args.use_kv_cache,
+        )
+
     return EagerEvalWrapper(
         model=model,
         tokenizer=tokenizer,
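
The dispatch above keys off two CLI values: --use_attention_sink, a comma-separated triple whose first two fields are the sink and window sizes (the third field is parsed but not consumed in this hunk), and --attention_sink_eval_length, which becomes the wrapper's max_seq_length. A small sketch with assumed example values, showing the parsing and the consistency check against args.max_seq_length:

# Assumed example values, not taken from the commit.
use_attention_sink = "4,2044,1024"   # sink_size, window_size, third field unused here
attention_sink_eval_length = 8192    # how many tokens to evaluate per sequence
max_seq_length = 2048                # must equal sink_size + window_size

params = use_attention_sink.split(",")
assert len(params) == 3
sink_size, window_size = int(params[0]), int(params[1])
assert max_seq_length == sink_size + window_size  # 4 + 2044 == 2048

# The wrapper's KV cache covers sink_size + window_size positions, while
# evaluation runs over sequences up to attention_sink_eval_length tokens.
cache_size = sink_size + window_size
print(cache_size, attention_sink_eval_length)  # 2048 8192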
@@ -279,6 +336,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
+    parser.add_argument(
+        "--attention_sink_eval_length",
+        type=int,
+        default=2048,
+        help="The maximum length of the sequence to evaluate with attention sink.",
+    )
 
     return parser
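
For reference, the new flag in isolation behaves like any other argparse integer option; a minimal sketch (the real build_args_parser defines many more arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention_sink_eval_length",
    type=int,
    default=2048,
    help="The maximum length of the sequence to evaluate with attention sink.",
)

# Default when the flag is omitted.
print(parser.parse_args([]).attention_sink_eval_length)  # 2048
# Overridden on the command line.
print(parser.parse_args(["--attention_sink_eval_length", "16384"]).attention_sink_eval_length)  # 16384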