diff --git a/IPAdapterPlus.py b/IPAdapterPlus.py index ce3517a..ab4270f 100644 --- a/IPAdapterPlus.py +++ b/IPAdapterPlus.py @@ -213,6 +213,7 @@ def ipadapter_execute(model, clipvision, insightface=None, image=None, + style_image_mask=None, image_composition=None, image_negative=None, weight=1.0, @@ -355,8 +356,17 @@ def ipadapter_execute(model, image = torch.stack(image) del image_iface, face + # style image mask + if style_image_mask is not None and style_image_mask.shape[1:3] != torch.Size([clipvision_size, clipvision_size]): + mask = style_image_mask.unsqueeze(1) + transforms = T.Compose([ + T.CenterCrop(min(mask.shape[2], mask.shape[3])), + T.Resize((clipvision_size, clipvision_size), interpolation=T.InterpolationMode.BICUBIC, antialias=True), + ]) + style_image_mask = transforms(mask).squeeze(1) + if image is not None: - img_cond_embeds = encode_image_masked(clipvision, image, batch_size=encode_batch_size, tiles=enhance_tiles, ratio=enhance_ratio, clipvision_size=clipvision_size) + img_cond_embeds = encode_image_masked(clipvision, image, mask=style_image_mask, batch_size=encode_batch_size, tiles=enhance_tiles, ratio=enhance_ratio, clipvision_size=clipvision_size) if image_composition is not None: img_comp_cond_embeds = encode_image_masked(clipvision, image_composition, batch_size=encode_batch_size, tiles=enhance_tiles, ratio=enhance_ratio, clipvision_size=clipvision_size) @@ -747,6 +757,7 @@ def INPUT_TYPES(s): "image_negative": ("IMAGE",), "attn_mask": ("MASK",), "clip_vision": ("CLIP_VISION",), + "style_image_mask": ("MASK",), } } @@ -754,7 +765,7 @@ def INPUT_TYPES(s): FUNCTION = "apply_ipadapter" CATEGORY = "ipadapter" - def apply_ipadapter(self, model, ipadapter, start_at=0.0, end_at=1.0, weight=1.0, weight_style=1.0, weight_composition=1.0, expand_style=False, weight_type="linear", combine_embeds="concat", weight_faceidv2=None, image=None, image_style=None, image_composition=None, image_negative=None, clip_vision=None, attn_mask=None, insightface=None, embeds_scaling='V only', layer_weights=None, ipadapter_params=None, encode_batch_size=0, style_boost=None, composition_boost=None, enhance_tiles=1, enhance_ratio=1.0, weight_kolors=1.0): + def apply_ipadapter(self, model, ipadapter, start_at=0.0, end_at=1.0, weight=1.0, weight_style=1.0, weight_composition=1.0, expand_style=False, weight_type="linear", combine_embeds="concat", weight_faceidv2=None, image=None, style_image_mask=None, image_style=None, image_composition=None, image_negative=None, clip_vision=None, attn_mask=None, insightface=None, embeds_scaling='V only', layer_weights=None, ipadapter_params=None, encode_batch_size=0, style_boost=None, composition_boost=None, enhance_tiles=1, enhance_ratio=1.0, weight_kolors=1.0): is_sdxl = isinstance(model.model, (comfy.model_base.SDXL, comfy.model_base.SDXLRefiner, comfy.model_base.SDXL_instructpix2pix)) if 'ipadapter' in ipadapter: @@ -797,6 +808,7 @@ def apply_ipadapter(self, model, ipadapter, start_at=0.0, end_at=1.0, weight=1.0 ipa_args = { "image": image[i], + "style_image_mask": style_image_mask if not isinstance(style_image_mask, list) else style_image_mask[i], "image_composition": image_composition, "image_negative": image_negative, "weight": weight[i],