@@ -323,7 +323,10 @@ def build_mask_cache(
     """
     # Usual causal mask:
     mask = torch.ones(
-        max_seq_length, max_seq_length, device=device, dtype=dtype,
+        max_seq_length,
+        max_seq_length,
+        device=device,
+        dtype=dtype,
     ).triu(diagonal=1)
     if sliding_window_size is not None:
         mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
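For reference, a minimal standalone sketch of what this construction produces; the concrete values of max_seq_length and sliding_window_size are illustrative assumptions, not taken from the PR:

import torch

# Illustrative values only.
max_seq_length, sliding_window_size = 6, 3

# Ones strictly above the diagonal mark future positions (the usual causal mask).
mask = torch.ones(max_seq_length, max_seq_length).triu(diagonal=1)
# With a sliding window, also mark positions more than
# `sliding_window_size` steps in the past.
mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
print(mask)
# Row i is zero exactly at columns j with i - sliding_window_size < j <= i,
# i.e. each query attends to itself and the previous sliding_window_size - 1 tokens.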
@@ -363,15 +366,23 @@ def build_mask_slice(
     tp_dtype = token_positions.dtype
     token_positions = token_positions.unsqueeze(2).to(device=device)
     kwargs = dict(device=device, dtype=tp_dtype)
-    bool_mask = torch.arange(
-        input_pos, input_pos + num, **kwargs,
-    ).view(1, 1, -1, 1) < token_positions
-    if sliding_window_size is not None:
-        extra_mask = torch.arange(
-            input_pos - sliding_window_size,
-            input_pos + num - sliding_window_size,
+    bool_mask = (
+        torch.arange(
+            input_pos,
+            input_pos + num,
             **kwargs,
-        ).view(1, 1, -1, 1) >= token_positions
+        ).view(1, 1, -1, 1)
+        < token_positions
+    )
+    if sliding_window_size is not None:
+        extra_mask = (
+            torch.arange(
+                input_pos - sliding_window_size,
+                input_pos + num - sliding_window_size,
+                **kwargs,
+            ).view(1, 1, -1, 1)
+            >= token_positions
+        )
         bool_mask += extra_mask
     mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device)
     mask.masked_fill_(bool_mask, torch.finfo(dtype).min)
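As a sanity check on the slice semantics: a query at absolute position q masks a key at token position t when q < t (future), and additionally when q - sliding_window_size >= t (outside the window). A minimal self-contained sketch; the batch/head shapes and the values of input_pos, num, and sliding_window_size are illustrative assumptions, not taken from the PR:

import torch

# Illustrative values only.
input_pos, num, sliding_window_size = 4, 2, 3
dtype, device = torch.float32, "cpu"
# Token positions of the cached keys for one batch element, one head: (1, 1, 6).
token_positions = torch.arange(6).view(1, 1, -1)

token_positions = token_positions.unsqueeze(2)  # -> (1, 1, 1, 6)
# Absolute positions of the `num` queries in this slice: (1, 1, num, 1).
q_pos = torch.arange(input_pos, input_pos + num).view(1, 1, -1, 1)
bool_mask = q_pos < token_positions  # mask keys in the future
# Also mask keys that fell out of the sliding window (bool += acts as logical or).
bool_mask += (q_pos - sliding_window_size) >= token_positions
mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device)
mask.masked_fill_(bool_mask, torch.finfo(dtype).min)
# Query at position 4 attends to keys {2, 3, 4}; query at 5 attends to {3, 4, 5}.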