44import torch
55from torch import Tensor
66import torch .nn .functional as F
7- from dataclasses import dataclass
7+ from dataclasses import dataclass , replace
88
99from comfy .sd import CLIP
1010from comfy .utils import ProgressBar
@@ -291,16 +291,21 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
291291 prev_holder : Union [CondHolder , None ] = None
292292 for idx , pair in enumerate (pairs ):
293293 holder = None
294+ is_over_length = False
294295 # if no last pair is set, then use first provided val up to the idx
295296 if prev_holder is None :
296297 for i in range (idx , pair .idx + 1 ):
297298 if i >= length :
299+ is_over_length = True
298300 continue
299301 real_prompt = apply_values_replace_to_prompt (pair .val , i , values_replace = values_replace )
300302 if holder is None or holder .prompt != real_prompt :
301303 cond , pooled = clip .encode_from_tokens (clip .tokenize (real_prompt ), return_pooled = True )
302304 cond = pad_cond (cond , target_length = max_size )
303305 holder = CondHolder (idx = i , prompt = real_prompt , raw_prompt = pair .val , cond = cond , pooled = pooled , hold = pair .hold )
306+ else :
307+ holder = replace (holder )
308+ holder .idx = i
304309 real_cond [i ] = cond
305310 real_pooled [i ] = pooled
306311 real_holders [i ] = holder
@@ -326,6 +331,7 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
326331 # however, need to check if real_prompt remains the same
327332 for i in range (prev_holder .idx + 1 , pair .idx ):
328333 if i >= length :
334+ is_over_length = True
329335 continue
330336 if holder is None :
331337 holder = prev_holder
@@ -334,6 +340,9 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
334340 cond , pooled = clip .encode_from_tokens (clip .tokenize (real_prompt ), return_pooled = True )
335341 cond = pad_cond (cond , target_length = max_size )
336342 holder = CondHolder (idx = i , prompt = real_prompt , raw_prompt = pair .val , cond = cond , pooled = pooled , hold = pair .hold )
343+ else :
344+ holder = replace (holder )
345+ holder .idx = i
337346 real_cond [i ] = holder .cond
338347 real_pooled [i ] = holder .pooled
339348 real_holders [i ] = holder
@@ -361,16 +370,17 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
361370 cond_from = None
362371 holder = None
363372 interm_holder = prev_holder
364- for idx , weight in zip (interp_idxs , interp_weights ):
365- if idx >= length :
373+ for raw_idx , weight in zip (interp_idxs , interp_weights ):
374+ if raw_idx >= length :
375+ is_over_length = True
366376 continue
367- idx_int = round (float (idx ))
377+ idx_int = round (float (raw_idx ))
368378 # calculate cond_to stuff if not done yet
369379 real_prompt = apply_values_replace_to_prompt (pair .val , idx_int , values_replace = values_replace )
370380 if holder is None or holder .prompt != real_prompt :
371381 cond_to , pooled_to = clip .encode_from_tokens (clip .tokenize (real_prompt ), return_pooled = True )
372382 cond_to = pad_cond (cond_to , target_length = max_size )
373- holder = CondHolder (idx = pair . idx , prompt = real_prompt , raw_prompt = pair .val , cond = cond_to , pooled = pooled_to , hold = pair .hold )
383+ holder = CondHolder (idx = idx_int , prompt = real_prompt , raw_prompt = pair .val , cond = cond_to , pooled = pooled_to , hold = pair .hold )
374384 # calculate interm_holder stuff if needed
375385 real_prompt = apply_values_replace_to_prompt (interm_holder .raw_prompt , idx_int , values_replace = values_replace )
376386 if interm_holder .prompt != real_prompt :
@@ -394,6 +404,8 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
394404 real_holders [idx_int ] = interm_holder
395405 pbar .update (1 )
396406 comfy .model_management .throw_exception_if_processing_interrupted ()
407+ if is_over_length :
408+ break
397409 assert holder is not None
398410 prev_holder = holder
399411
0 commit comments