Skip to content

Commit 406155c

Browse files
authored
Merge PR #503 from Kosinkadink/develop
Fixed PromptScheduling AssertionError
2 parents edb939f + 268f7fb commit 406155c

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

animatediff/scheduling.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch import Tensor
66
import torch.nn.functional as F
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, replace
88

99
from comfy.sd import CLIP
1010
from comfy.utils import ProgressBar
@@ -291,16 +291,21 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
291291
prev_holder: Union[CondHolder, None] = None
292292
for idx, pair in enumerate(pairs):
293293
holder = None
294+
is_over_length = False
294295
# if no last pair is set, then use first provided val up to the idx
295296
if prev_holder is None:
296297
for i in range(idx, pair.idx+1):
297298
if i >= length:
299+
is_over_length = True
298300
continue
299301
real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace)
300302
if holder is None or holder.prompt != real_prompt:
301303
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
302304
cond = pad_cond(cond, target_length=max_size)
303305
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
306+
else:
307+
holder = replace(holder)
308+
holder.idx = i
304309
real_cond[i] = cond
305310
real_pooled[i] = pooled
306311
real_holders[i] = holder
@@ -326,6 +331,7 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
326331
# however, need to check if real_prompt remains the same
327332
for i in range(prev_holder.idx+1, pair.idx):
328333
if i >= length:
334+
is_over_length = True
329335
continue
330336
if holder is None:
331337
holder = prev_holder
@@ -334,6 +340,9 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
334340
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
335341
cond = pad_cond(cond, target_length=max_size)
336342
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
343+
else:
344+
holder = replace(holder)
345+
holder.idx = i
337346
real_cond[i] = holder.cond
338347
real_pooled[i] = holder.pooled
339348
real_holders[i] = holder
@@ -361,16 +370,17 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
361370
cond_from = None
362371
holder = None
363372
interm_holder = prev_holder
364-
for idx, weight in zip(interp_idxs, interp_weights):
365-
if idx >= length:
373+
for raw_idx, weight in zip(interp_idxs, interp_weights):
374+
if raw_idx >= length:
375+
is_over_length = True
366376
continue
367-
idx_int = round(float(idx))
377+
idx_int = round(float(raw_idx))
368378
# calculate cond_to stuff if not done yet
369379
real_prompt = apply_values_replace_to_prompt(pair.val, idx_int, values_replace=values_replace)
370380
if holder is None or holder.prompt != real_prompt:
371381
cond_to, pooled_to = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
372382
cond_to = pad_cond(cond_to, target_length=max_size)
373-
holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold)
383+
holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold)
374384
# calculate interm_holder stuff if needed
375385
real_prompt = apply_values_replace_to_prompt(interm_holder.raw_prompt, idx_int, values_replace=values_replace)
376386
if interm_holder.prompt != real_prompt:
@@ -394,6 +404,8 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
394404
real_holders[idx_int] = interm_holder
395405
pbar.update(1)
396406
comfy.model_management.throw_exception_if_processing_interrupted()
407+
if is_over_length:
408+
break
397409
assert holder is not None
398410
prev_holder = holder
399411

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-animatediff-evolved"
33
description = "Improved AnimateDiff integration for ComfyUI."
4-
version = "1.3.1"
4+
version = "1.3.2"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)