Skip to content

Commit 50bb4c4

Browse files
authored
Update torch2coreml.py
1 parent 7e3df29 commit 50bb4c4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
337337
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
338338
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
339339

340-
modeling_clip._make_causal_mask = patched_make_causal_mask # For transformers < 4.35.0
340+
modeling_clip._make_causal_mask = patched_make_causal_mask # For transformers >= 4.30.0 and transformers < 4.35.0
341341
modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
342342

343343
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):

0 commit comments

Comments
 (0)