Commit 7e3df29

Update torch2coreml.py
1 parent 4ef83f1 commit 7e3df29

1 file changed (+3 -3)

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 3 additions & 3 deletions
@@ -322,6 +322,8 @@ def bundle_resources_for_swift_cli(args):
 from transformers.models.clip import modeling_clip
 
 # Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
+# Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
+# For backward compatibility with versions < 4.35.0, both functions are patched here.
 def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
     """ Patch to replace torch.finfo(dtype).min with -1e4
     """
@@ -334,9 +336,7 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
     if past_key_values_length > 0:
         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-# Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
-# For backward compatibility with versions < 4.35.0, both functions are patched here.
+
 modeling_clip._make_causal_mask = patched_make_causal_mask # For transformers < 4.35.0
 modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
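
For reference, here is a minimal, self-contained sketch of the patched code path after this commit. The imports, the patched_make_causal_mask signature and docstring, the version comments, and the two module-level assignments come from the diff above; the mask-construction lines that fall between the two hunks are reconstructed from the transformers v4.30.0 _make_causal_mask referenced in the comment (with torch.finfo(dtype).min swapped for -1e4), so treat them as an approximation rather than the exact file contents.

import torch
from transformers.models.clip import modeling_clip

# Behavior copied from modeling_clip._make_causal_mask (transformers v4.30.0),
# with torch.finfo(dtype).min replaced by -1e4, as the docstring below states.
def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
    """ Patch to replace torch.finfo(dtype).min with -1e4 """
    bsz, tgt_len = input_ids_shape
    # Additive causal mask: -1e4 above the diagonal, 0 on and below it.
    mask = torch.full((tgt_len, tgt_len), -1e4, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

# Patch both names: transformers < 4.35.0 calls _make_causal_mask,
# while transformers >= 4.35.0 calls _create_4d_causal_attention_mask.
modeling_clip._make_causal_mask = patched_make_causal_mask  # For transformers < 4.35.0
modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask  # For transformers >= 4.35.0

Because both attribute names point at the same function, whichever name the installed modeling_clip looks up resolves to the patched implementation, which is why this commit assigns both rather than only the newer one.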
