@@ -322,6 +322,8 @@ def bundle_resources_for_swift_cli(args):
322
322
from transformers .models .clip import modeling_clip
323
323
324
324
# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
325
+ # Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
326
+ # For backward compatibility with versions < 4.35.0, both functions are patched here.
325
327
def patched_make_causal_mask (input_ids_shape , dtype , device , past_key_values_length : int = 0 ):
326
328
""" Patch to replace torch.finfo(dtype).min with -1e4
327
329
"""
@@ -334,9 +336,7 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
334
336
if past_key_values_length > 0 :
335
337
mask = torch .cat ([torch .zeros (tgt_len , past_key_values_length , dtype = dtype , device = device ), mask ], dim = - 1 )
336
338
return mask [None , None , :, :].expand (bsz , 1 , tgt_len , tgt_len + past_key_values_length )
337
-
338
- # Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
339
- # For backward compatibility with versions < 4.35.0, both functions are patched here.
339
+
340
340
modeling_clip ._make_causal_mask = patched_make_causal_mask # For transformers < 4.35.0
341
341
modeling_clip ._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
342
342
0 commit comments