@@ -439,7 +439,6 @@ def forward(
439439 """
440440 # 1) Project key, value, and query.
441441 # as a reminder at training layer_cache[0] remains False
442- key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
443442 if self .layer_cache [0 ]:
444443 # Retrieve keys and values from the KV cache (decoding mode only).
445444 if self .attn_type == "self" :
@@ -484,6 +483,16 @@ def forward(
                         key = key[:, :, 1:, :]
                         value = value[:, :, 1:, :]
 
+                if step == 0:
+                    key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
+                    if key_pad_mask is not None:
+                        x = key_pad_mask.expand(
+                            -1, self.head_count // self.parallel_gpu, -1
+                        )
+                        x = x.unsqueeze(3)
+                        x = x.expand(-1, -1, -1, value.size(3))
+                        value = value.masked_fill(x, 0)
+
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
 
@@ -565,19 +574,6 @@ def forward(
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
 
-            if key_pad_mask is not None:
-                # Increase the cached key pad mask by concatenation.
-                # For decoding only.
-                if step > 0:
-                    y = torch.zeros(
-                        (key_pad_mask.size(0), key_pad_mask.size(1), 1),
-                        dtype=torch.bool,
-                        device=key_pad_mask.device,
-                    )
-                    self.layer_cache[1]["key_pad_mask"] = torch.cat(
-                        (key_pad_mask, y), 2
-                    )
-                    key_pad_mask = self.layer_cache[1]["key_pad_mask"]
         else:
             # Retrieve keys and values from linear layers (training mode).
             key = self.maybe_ckpt(self.linear_keys, key)
@@ -706,8 +702,6 @@ def forward(
             scores = self.alibi(scores)
 
         scores = scores.float()
-        if key_pad_mask is not None and mask is None:
-            mask = key_pad_mask.unsqueeze(1)
 
         if mask is not None:
             # not 100% necessary but expand to nb of heads
@@ -727,10 +721,6 @@ def forward(
             attn_output.add_(relative_matmul(drop_attn, relations_values, False))
 
         context = unshape(attn_output)
-        if key_pad_mask is not None:
-            if key_pad_mask.size(0) > 1 and context.size(1) > 1:
-                x = key_pad_mask.squeeze(1).unsqueeze(2).expand(-1, -1, context.size(2))
-                context = context.masked_fill(x, 0)
 
         if self.layer_cache[0]:
             attn_output = self.final_linear(context)
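
Taken together, these hunks move pad handling out of the per-step decode path: instead of caching a key_pad_mask, growing it by one column every step, and masking both the scores and the output context, the mask is now applied once at step 0 by zeroing the value vectors of padded positions before they enter the KV cache. Below is a minimal standalone sketch of that idiom; the helper name and toy shapes are illustrative, not the module's API. It assumes, as the removed code implies, that key_pad_mask has shape [batch, 1, seq_len] with True at padded positions (in the diff itself the head dimension is head_count // parallel_gpu, i.e. heads per tensor-parallel shard).

import torch

def zero_padded_values(value: torch.Tensor, key_pad_mask: torch.Tensor) -> torch.Tensor:
    """Zero value vectors at padded key positions (hypothetical helper).

    value:        [batch, heads, seq_len, dim_per_head]
    key_pad_mask: [batch, 1, seq_len], True where the key is padding.
    """
    heads = value.size(1)
    # Broadcast the mask over heads: [batch, 1, seq_len] -> [batch, heads, seq_len]
    x = key_pad_mask.expand(-1, heads, -1)
    # Add the feature dim and broadcast over it to match value's shape
    x = x.unsqueeze(3).expand(-1, -1, -1, value.size(3))
    return value.masked_fill(x, 0)

batch, heads, seq_len, dim = 2, 4, 5, 8
value = torch.randn(batch, heads, seq_len, dim)
key_pad_mask = torch.zeros(batch, 1, seq_len, dtype=torch.bool)
key_pad_mask[:, :, -2:] = True  # pretend the last two positions are padding

masked = zero_padded_values(value, key_pad_mask)
assert masked[:, :, -2:, :].abs().sum() == 0  # padded rows are all zeros

The trade-off: the softmax still assigns some probability mass to padded key positions, but their zeroed value rows contribute nothing to the attention-weighted sum, and the decoder no longer has to re-read and concatenate a pad mask at every step.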