Skip to content

Commit f5030e2

Browse files
Add progress bar to ace step. (#12242)
1 parent 66e1b07 commit f5030e2

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

comfy/text_encoders/ace15.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from comfy import sd1_clip
44
import torch
55
import math
6+
import comfy.utils
67

78

89
def sample_manual_loop_no_classes(
@@ -42,6 +43,8 @@ def sample_manual_loop_no_classes(
4243
for x in range(model_config.num_hidden_layers):
4344
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
4445

46+
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
47+
4548
for step in range(max_new_tokens):
4649
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
4750
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
@@ -90,6 +93,7 @@ def sample_manual_loop_no_classes(
9093
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
9194

9295
output_audio_codes.append(token - audio_start_id)
96+
progress_bar.update_absolute(step)
9397

9498
return output_audio_codes
9599

0 commit comments

Comments
 (0)