Skip to content

Commit ebff5f7

Browse files
author
Atila Orhon
committed
fixes #246
1 parent 35613e5 commit ebff5f7

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,24 @@ def bundle_resources_for_swift_cli(args):
272272
return resources_dir
273273

274274

275+
from transformers.models.clip import modeling_clip
276+
277+
# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
278+
def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
    """ Patch to replace torch.finfo(dtype).min with -1e4

    Builds the additive causal mask: entries are 0 where the column index is
    less than or equal to the row index, and -1e4 everywhere else.

    Args:
        input_ids_shape: Two-element shape, unpacked as (bsz, tgt_len).
        dtype: torch dtype of the returned mask.
        device: torch device on which the mask tensors are created.
        past_key_values_length: Number of all-zero columns prepended to the
            mask along the last dimension. Nothing is prepended when 0.

    Returns:
        A tensor of shape
        (bsz, 1, tgt_len, tgt_len + past_key_values_length) in `dtype`,
        produced by `expand`-ing a single (tgt_len, width) mask over the
        batch dimension.
    """
    bsz, tgt_len = input_ids_shape

    # Start fully masked. The finite constant -1e4 is used in place of
    # torch.finfo(dtype).min (see the docstring's patch note); it is passed as
    # a plain scalar, so no intermediate 0-dim tensor has to be allocated.
    mask = torch.full((tgt_len, tgt_len), -1e4, device=device)

    # Row i keeps columns j with j < i + 1 (i.e. j <= i); those are zeroed.
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        # Prepend zero (unmasked) columns for the past positions.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
290+
291+
# Monkey-patch: rebind the `_make_causal_mask` attribute of the imported
# transformers CLIP modeling module to `patched_make_causal_mask`, whose mask
# uses -1e4 instead of torch.finfo(dtype).min.
# NOTE(review): this assignment succeeds even if the installed transformers
# version does not define or call `modeling_clip._make_causal_mask` -- confirm
# against the pinned transformers version.
modeling_clip._make_causal_mask = patched_make_causal_mask
292+
275293
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
276294
""" Converts the text encoder component of Stable Diffusion
277295
"""

0 commit comments

Comments
 (0)