Skip to content

Commit e12202c

Browse files
authored
Merge pull request #368 from Aintor/main
Fix for Mask Patch Failure and Quantization Issues in Latest `transformers` Versions
2 parents e17915a + 50bb4c4 commit e12202c

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -358,6 +358,8 @@ def bundle_resources_for_swift_cli(args):
358358
from transformers.models.clip import modeling_clip
359359

360360
# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
361+
# Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
362+
# For backward compatibility with versions < 4.35.0, both functions are patched here.
361363
def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
362364
""" Patch to replace torch.finfo(dtype).min with -1e4
363365
"""
@@ -370,8 +372,9 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
370372
if past_key_values_length > 0:
371373
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
372374
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
373-
374-
modeling_clip._make_causal_mask = patched_make_causal_mask
375+
376+
modeling_clip._make_causal_mask = patched_make_causal_mask # For transformers >= 4.30.0 and transformers < 4.35.0
377+
modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
375378

376379
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
377380
""" Converts the text encoder component of Stable Diffusion

0 commit comments

Comments (0)