diff --git a/CrossAttentionPatch.py b/CrossAttentionPatch.py
index ac2d0fe..225beb5 100644
--- a/CrossAttentionPatch.py
+++ b/CrossAttentionPatch.py
@@ -6,7 +6,7 @@ class CrossAttentionPatch:
     # forward for patching
-    def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
+    def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, image_schedule=None, embeds_scaling='V only'):
         self.weights = [weight]
         self.ipadapters = [ipadapter]
         self.conds = [cond]
@@ -17,6 +17,7 @@ def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=Non
         self.sigma_starts = [sigma_start]
         self.sigma_ends = [sigma_end]
         self.unfold_batch = [unfold_batch]
+        self.image_schedule = [image_schedule]
         self.embeds_scaling = [embeds_scaling]
         self.number = number
         self.layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16 # TODO: check if this is a valid condition to detect all models
@@ -24,7 +25,7 @@ def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=Non
         self.k_key = str(self.number*2+1) + "_to_k_ip"
         self.v_key = str(self.number*2+1) + "_to_v_ip"
 
-    def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
+    def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, image_schedule=None, embeds_scaling='V only'):
         self.weights.append(weight)
         self.ipadapters.append(ipadapter)
         self.conds.append(cond)
@@ -35,6 +36,7 @@ def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, con
         self.sigma_starts.append(sigma_start)
         self.sigma_ends.append(sigma_end)
         self.unfold_batch.append(unfold_batch)
+        self.image_schedule.append(image_schedule)
         self.embeds_scaling.append(embeds_scaling)
 
     def __call__(self, q, k, v, extra_options):
@@ -54,7 +56,7 @@ def __call__(self, q, k, v, extra_options):
         out = optimized_attention(q, k, v, extra_options["n_heads"])
         _, _, oh, ow = extra_options["original_shape"]
 
-        for weight, cond, cond_alt, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, embeds_scaling in zip(self.weights, self.conds, self.conds_alt, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.embeds_scaling):
+        for weight, cond, cond_alt, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, image_schedule, embeds_scaling in zip(self.weights, self.conds, self.conds_alt, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.image_schedule, self.embeds_scaling):
             if sigma <= sigma_start and sigma >= sigma_end:
                 if weight_type == 'ease in':
                     weight = weight * (0.05 + 0.95 * (1 - t_idx / self.layers))
@@ -94,16 +96,23 @@ def __call__(self, q, k, v, extra_options):
                         elif weight == 0:
                             continue
 
-                        # if image length matches or exceeds full_length get sub_idx images
-                        if cond.shape[0] >= ad_params["full_length"]:
-                            cond = torch.Tensor(cond[ad_params["sub_idxs"]])
-                            uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
-                        # otherwise get sub_idxs images
+                        if image_schedule is not None:
+                            # Use the image_schedule as a lookup table to get the embedded image corresponding to each sub_idx
+                            # If image_schedule isn't long enough then use the last image
+                            cond_idxs = [image_schedule[i if i < len(image_schedule) else -1] for i in ad_params["sub_idxs"]]
+                            cond = torch.Tensor(cond[cond_idxs])
+                            uncond = torch.Tensor(uncond[cond_idxs])
                         else:
-                            cond = tensor_to_size(cond, ad_params["full_length"])
-                            uncond = tensor_to_size(uncond, ad_params["full_length"])
-                            cond = cond[ad_params["sub_idxs"]]
-                            uncond = uncond[ad_params["sub_idxs"]]
+                            # if image length matches or exceeds full_length get sub_idx images
+                            if cond.shape[0] >= ad_params["full_length"]:
+                                cond = torch.Tensor(cond[ad_params["sub_idxs"]])
+                                uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
+                            # otherwise get sub_idxs images
+                            else:
+                                cond = tensor_to_size(cond, ad_params["full_length"])
+                                uncond = tensor_to_size(uncond, ad_params["full_length"])
+                                cond = cond[ad_params["sub_idxs"]]
+                                uncond = uncond[ad_params["sub_idxs"]]
                     else:
                         if isinstance(weight, torch.Tensor):
                             weight = tensor_to_size(weight, batch_prompt)
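
The new branch in CrossAttentionPatch.__call__ treats image_schedule as a per-frame lookup table: each AnimateDiff sub_idx selects the index of the embedded image to use for that frame, and frames past the end of the schedule fall back to the last entry. A minimal standalone sketch of that indexing, with toy values and shapes rather than the real tensors:

    import torch

    image_schedule = [0, 0, 1, 2]   # frame index -> embedded image index (illustrative)
    cond = torch.randn(3, 4, 8)     # three encoded images (illustrative shape)
    sub_idxs = [1, 2, 3, 4, 5]      # frames in the current AnimateDiff context window

    # frames beyond the schedule reuse the last scheduled image (index -1)
    cond_idxs = [image_schedule[i if i < len(image_schedule) else -1] for i in sub_idxs]
    print(cond_idxs)                # [0, 1, 2, 2, 2]
    print(cond[cond_idxs].shape)    # torch.Size([5, 4, 8])
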
"attn_mask": ("MASK",), "clip_vision": ("CLIP_VISION",), + "image_schedule": ("INT", {"default": None, "forceInput": True} ), } } @@ -963,6 +968,7 @@ def INPUT_TYPES(s): "image_negative": ("IMAGE",), "attn_mask": ("MASK",), "clip_vision": ("CLIP_VISION",), + "image_schedule": ("INT", {"default": None, "forceInput": True} ), } } @@ -1330,6 +1336,37 @@ def load(self, embeds): path = folder_paths.get_annotated_filepath(embeds) return (torch.load(path).cpu(), ) +defaultValue="""0:0, +40:1, +80:2, +""" +class IPAdapterImageSchedule: + @classmethod + def INPUT_TYPES(s): + return {"required": {"text": ("STRING", {"multiline": True, "default": defaultValue}), + "max_frames": ("INT", {"default": 120.0, "min": 1.0, "max": 999999.0, "step": 1.0}), + "print_output": ("BOOLEAN", {"default": False})}} + + RETURN_TYPES = ("INT",) + FUNCTION = "schedule" + + CATEGORY = "ipadapter/utils" + + def schedule(self, text, max_frames, print_output): + frames = [0] * max_frames + for item in text.split(","): + item = item.strip() + if ":" in item: + parts = item.split(":") + if len(parts) == 2: + start_frame = int(parts[0]) + value = int(parts[1]) + for i in range(start_frame, max_frames): + frames[i] = value + if print_output is True: + print("ValueSchedule: ", frames) + return (frames, ) + class IPAdapterWeights: @classmethod def INPUT_TYPES(s): @@ -1521,8 +1558,10 @@ def combine(self, params_1, params_2, params_3=None, params_4=None, params_5=Non "IPAdapterSaveEmbeds": IPAdapterSaveEmbeds, "IPAdapterLoadEmbeds": IPAdapterLoadEmbeds, "IPAdapterWeights": IPAdapterWeights, + "IPAdapterImageSchedule": IPAdapterImageSchedule, "IPAdapterRegionalConditioning": IPAdapterRegionalConditioning, "IPAdapterCombineParams": IPAdapterCombineParams, + "IPAdapterCombineParams": IPAdapterCombineParams, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1555,6 +1594,7 @@ def combine(self, params_1, params_2, params_3=None, params_4=None, params_5=Non "IPAdapterSaveEmbeds": "IPAdapter Save Embeds", "IPAdapterLoadEmbeds": "IPAdapter Load Embeds", "IPAdapterWeights": "IPAdapter Weights", + "IPAdapterImageSchedule": "IPAdapterImageSchedule", "IPAdapterRegionalConditioning": "IPAdapter Regional Conditioning", "IPAdapterCombineParams": "IPAdapter Combine Params", } \ No newline at end of file diff --git a/utils.py b/utils.py index 5d1811a..01f8ff5 100644 --- a/utils.py +++ b/utils.py @@ -154,21 +154,34 @@ def insightface_loader(provider): model.prepare(ctx_id=0, det_size=(640, 640)) return model -def encode_image_masked(clip_vision, image, mask=None): +def encode_image_masked(clip_vision, images, mask=None): model_management.load_model_gpu(clip_vision.patcher) - image = image.to(clip_vision.load_device) - pixel_values = clip_preprocess(image.to(clip_vision.load_device)).float() + # Initialize lists to collect outputs + last_hidden_states = [] + image_embeds = [] + penultimate_hidden_states = [] - if mask is not None: - pixel_values = pixel_values * mask.to(clip_vision.load_device) + # Loop over each image in the batch + for image in images: + pixel_values = clip_preprocess(image.to(clip_vision.load_device).unsqueeze(0)).float() - out = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2) + if mask is not None: + pixel_values *= mask.to(clip_vision.load_device) + out = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2) + + # Collect the outputs for each image + last_hidden_states.append(out[0].to(model_management.intermediate_device())) + image_embeds.append(out[2].to(model_management.intermediate_device())) + 
diff --git a/utils.py b/utils.py
index 5d1811a..01f8ff5 100644
--- a/utils.py
+++ b/utils.py
@@ -154,21 +154,34 @@ def insightface_loader(provider):
     model.prepare(ctx_id=0, det_size=(640, 640))
     return model
 
-def encode_image_masked(clip_vision, image, mask=None):
+def encode_image_masked(clip_vision, images, mask=None):
     model_management.load_model_gpu(clip_vision.patcher)
-    image = image.to(clip_vision.load_device)
-    pixel_values = clip_preprocess(image.to(clip_vision.load_device)).float()
+    # Initialize lists to collect outputs
+    last_hidden_states = []
+    image_embeds = []
+    penultimate_hidden_states = []
 
-    if mask is not None:
-        pixel_values = pixel_values * mask.to(clip_vision.load_device)
+    # Loop over each image in the batch
+    for image in images:
+        pixel_values = clip_preprocess(image.to(clip_vision.load_device).unsqueeze(0)).float()
 
-    out = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)
+        if mask is not None:
+            pixel_values *= mask.to(clip_vision.load_device)
+        out = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)
+
+        # Collect the outputs for each image
+        last_hidden_states.append(out[0].to(model_management.intermediate_device()))
+        image_embeds.append(out[2].to(model_management.intermediate_device()))
+        penultimate_hidden_states.append(out[1].to(model_management.intermediate_device()))
 
+    # Concatenate all collected outputs across the batch
     outputs = Output()
-    outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
-    outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
-    outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
+    outputs["last_hidden_state"] = torch.cat(last_hidden_states, dim=0)
+    outputs["image_embeds"] = torch.cat(image_embeds, dim=0)
+    outputs["penultimate_hidden_states"] = torch.cat(penultimate_hidden_states, dim=0)
+
     return outputs
 
 def tensor_to_size(source, dest_size):
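
The reworked encode_image_masked runs the CLIP vision model once per image in the batch and concatenates the per-image outputs along dim 0, so the first dimension of image_embeds and the hidden states matches the number of input images and keeps the input order, which is what the per-frame schedule ultimately indexes into. A minimal sketch of that accumulate-then-concatenate pattern, with a hypothetical stand-in for the actual CLIP vision call:

    import torch

    def encode_one(image: torch.Tensor) -> torch.Tensor:
        # stand-in for clip_vision.model(...); returns one (1, D) embedding per image
        return image.mean(dim=(0, 1)).unsqueeze(0)

    images = torch.rand(5, 224, 224, 3)   # batch of 5 images, ComfyUI-style HWC
    embeds = torch.cat([encode_one(img) for img in images], dim=0)
    print(embeds.shape)                    # torch.Size([5, 3]), one row per input image
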