@@ -358,6 +358,8 @@ def bundle_resources_for_swift_cli(args):
358
358
from transformers.models.clip import modeling_clip
359
359
360
360
# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
361
+ # Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
362
+ # For backward compatibility with versions < 4.35.0, both functions are patched here.
361
363
def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
362
364
""" Patch to replace torch.finfo(dtype).min with -1e4
363
365
"""
@@ -370,8 +372,9 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
370
372
if past_key_values_length > 0:
371
373
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
372
374
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
373
-
374
- modeling_clip._make_causal_mask = patched_make_causal_mask
375
+
376
+ modeling_clip._make_causal_mask = patched_make_causal_mask  # For transformers >= 4.30.0 and transformers < 4.35.0
377
+ modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask  # For transformers >= 4.35.0
375
378
376
379
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
377
380
""" Converts the text encoder component of Stable Diffusion
0 commit comments