@@ -378,39 +378,35 @@ def _compute_self_attention_mask(
378
378
decoder_sequence ,
379
379
decoder_padding_mask ,
380
380
decoder_attention_mask ,
381
- use_causal_mask ,
382
381
self_attention_cache ,
383
382
self_attention_cache_update_index ,
384
383
):
385
384
decoder_mask = merge_padding_and_attention_mask (
386
385
decoder_sequence , decoder_padding_mask , decoder_attention_mask
387
386
)
388
- if use_causal_mask :
389
- batch_size = ops .shape (decoder_sequence )[0 ]
390
- input_length = output_length = ops .shape (decoder_sequence )[1 ]
391
- # We need to handle a rectangular causal mask when doing cached
392
- # decoding. For generative inference, `decoder_sequence` will
393
- # generally be length 1, and `cache` will be the full generation
394
- # length.
395
- if self_attention_cache is not None :
396
- input_length = ops .shape (self_attention_cache )[2 ]
397
-
398
- causal_mask = compute_causal_mask (
399
- batch_size ,
400
- input_length ,
401
- output_length ,
402
- (
403
- 0
404
- if self_attention_cache_update_index is None
405
- else self_attention_cache_update_index
406
- ),
407
- )
408
- return (
409
- ops .minimum (decoder_mask , causal_mask )
410
- if decoder_mask is not None
411
- else causal_mask
412
- )
413
- return decoder_mask
387
+ batch_size = ops .shape (decoder_sequence )[0 ]
388
+ input_length = output_length = ops .shape (decoder_sequence )[1 ]
389
+ # We need to handle a rectangular causal mask when doing cached
390
+ # decoding. For generative inference, `decoder_sequence` will
391
+ # generally be length 1, and `cache` will be the full generation length.
392
+ if self_attention_cache is not None :
393
+ input_length = ops .shape (self_attention_cache )[2 ]
394
+
395
+ cache_update_index = (
396
+ 0
397
+ if self_attention_cache_update_index is None
398
+ else self_attention_cache_update_index
399
+ )
400
+
401
+ causal_mask = compute_causal_mask (
402
+ batch_size , input_length , output_length , cache_update_index
403
+ )
404
+
405
+ return (
406
+ ops .minimum (decoder_mask , causal_mask )
407
+ if decoder_mask is not None
408
+ else causal_mask
409
+ )
414
410
415
411
def build (self , input_shape ):
416
412
"""
0 commit comments