Skip to content

Commit 4ef83f1

Browse files
authored
Update torch2coreml.py
1 parent 2447030 commit 4ef83f1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,10 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
335335
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
336336
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
337337

338-
modeling_clip._make_causal_mask = patched_make_causal_mask
339-
modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask
338+
# Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
339+
# For backward compatibility with versions < 4.35.0, both functions are patched here.
340+
modeling_clip._make_causal_mask = patched_make_causal_mask # For transformers < 4.35.0
341+
modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
340342

341343
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
342344
""" Converts the text encoder component of Stable Diffusion

0 commit comments

Comments
 (0)