@@ -155,6 +155,44 @@ def _model_call(self, inps):
        pass


+class AttentionSinkEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for evaluating the model with attention sink.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        sink_size: int,
+        window_size: int,
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        super().__init__(model, tokenizer, max_seq_length, use_kv_cache)
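+        # Total kv-cache capacity: sink_size initial "sink" tokens that are
+        # never evicted, plus a sliding window of window_size recent tokens.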
+        self.cache_size = sink_size + window_size
+        assert self._use_kv_cache, "Attention sink only works with kv cache."
+
+    def _model_call(self, inps):
+        # Given inps (tokens), return the logits.
+
+        # Example:
+        # inps: Tensor of shape (1, N)
+        # logits: Tensor of shape (1, N, vocab_size)
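+        # Prefill up to cache_size tokens in a single call, then feed the
+        # remaining tokens one position at a time so the kv cache can evict
+        # from the sliding window while keeping the sink tokens.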
+        _, seq_len = inps.shape
+        result = self._model(
+            inps[:, : min(seq_len, self.cache_size)],
+            torch.tensor([0], dtype=torch.int64, device=self.device),
+        )
+        for pos in range(min(seq_len, self.cache_size), seq_len):
+            logits = self._model(
+                inps[:, pos : pos + 1],
+                torch.tensor([pos], dtype=torch.int64, device=self.device),
+            )
+            # Append this step's logits; slice with -1: to keep the sequence dim.
+            result = torch.cat([result, logits[:, -1:, :]], dim=1)
+        return result
+
+
def gen_eval_wrapper(
    model_name: str,
    args: argparse.ArgumentParser,
@@ -225,6 +263,25 @@ def gen_eval_wrapper(
    if args.output_eager_checkpoint_file is not None:  # pyre-ignore
        torch.save(model, args.output_eager_checkpoint_file)

+    if (use_attention_sink := args.use_attention_sink) is not None and (
+        attention_sink_eval_length := args.attention_sink_eval_length
+    ) is not None:  # pyre-ignore
+        # --use_attention_sink is a comma-separated triple; only the first
+        # two values (sink size and window size) are used here.
+        attention_sink_params = use_attention_sink.split(",")
+        assert len(attention_sink_params) == 3
+        sink_size = int(attention_sink_params[0])
+        window_size = int(attention_sink_params[1])
+
+        # The exported model's kv cache must be sized exactly to the
+        # sink + window capacity.
+        assert args.max_seq_length == sink_size + window_size
+
+        return AttentionSinkEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            sink_size=sink_size,
+            window_size=window_size,
+            max_seq_length=attention_sink_eval_length,
+            use_kv_cache=args.use_kv_cache,
+        )
+
    return EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
@@ -279,6 +336,12 @@ def build_args_parser() -> argparse.ArgumentParser:
        default=None,
        help="Save the checkpoint after source transformations, for other evaluation platforms to run the same checkpoint.",
    )
+    parser.add_argument(
+        "--attention_sink_eval_length",
+        type=int,
+        default=2048,
+        help="The maximum length of the sequence to evaluate with attention sink.",
+    )

    return parser

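For reference, a minimal usage sketch of the new wrapper, not part of the diff: `model` and `tokenizer` are placeholders for the objects gen_eval_wrapper builds for EagerEvalWrapper, and the concrete sizes below are illustrative, not mandated by this change.

# Hypothetical usage sketch; `model` and `tokenizer` are assumed to be
# built the same way gen_eval_wrapper builds them.
import torch

wrapper = AttentionSinkEvalWrapper(
    model=model,
    tokenizer=tokenizer,
    sink_size=4,
    window_size=2044,     # cache_size = 4 + 2044 = 2048
    max_seq_length=4096,  # evaluate sequences longer than the cache
    use_kv_cache=True,
)

# Tokens past cache_size are decoded one position at a time, so the
# returned logits still cover every input position.
inps = torch.randint(0, 32000, (1, 4096))  # 32000 is a placeholder vocab size
logits = wrapper._model_call(inps)         # shape: (1, 4096, vocab_size)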