Skip to content

Commit 2447030

Browse files
authored
Update torch2coreml.py
1 parent bef26ae commit 2447030

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
336336
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
337337

338338
modeling_clip._make_causal_mask = patched_make_causal_mask
339+
modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask
339340

340341
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
341342
""" Converts the text encoder component of Stable Diffusion

0 commit comments

Comments
 (0)