
Commit 6412422

bug fixes + added some features
1 parent 233e441 commit 6412422

5 files changed: +126 −26 lines changed

5 files changed

+126
-26
lines changed

comfy/autoregressive_sampling.py

Lines changed: 61 additions & 9 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 import copy
 import torch
 import inspect
@@ -56,9 +57,9 @@ def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
 
 class TemperatureLogitsWarper:
     def __init__(self, temperature: float):
-
         if not (temperature > 0):
-            raise ValueError(f"`temperature` (={temperature}) must be positive temperature > 0")
+            raise ValueError(f"`temperature` (={temperature}) must be a positive number > 0")
+
         self.temperature = temperature
 
     def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -86,10 +87,30 @@ def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores_processed
 
+class MinLengthLogitsProcessor:
+    def __init__(self, min_length: int, eos_token_id: torch.Tensor):
+        self.min_length = min_length
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        if input_ids is None:
+            return scores
+
+        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
+        scores_processed = scores.clone()
+        if input_ids.shape[-1] < self.min_length:
+            scores_processed = torch.where(eos_token_mask, -math.inf, scores)
+        return scores_processed
+
 def get_logits_processing(config: GenerationConfig):
     # TODO: add support for beam search with diversity penalty
     logits_processors = []
 
+    if config._eos_token_tensor is not None and config.min_length > 1:
+        logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
+
     if config.top_k is not None and config.top_k != 0:
         logits_processors.append(TopKLogits(config.top_k))
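A quick illustration of the new MinLengthLogitsProcessor, not part of the commit: until the sequence reaches min_length, the EOS column is forced to -inf so sampling cannot terminate early. Minimal sketch with toy tensors (hypothetical 5-token vocabulary, EOS id 4):

import torch
from comfy.autoregressive_sampling import MinLengthLogitsProcessor

eos = torch.tensor([4])                                   # hypothetical EOS token id
proc = MinLengthLogitsProcessor(min_length=3, eos_token_id=eos)

scores = torch.zeros(1, 5)                                # uniform toy logits over a 5-token vocab
short = torch.tensor([[7, 2]])                            # only 2 tokens generated so far
print(proc(short, scores))                                # EOS column (index 4) is set to -inf
print(proc(torch.tensor([[7, 2, 3]]), scores))            # length >= min_length, scores unchanged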

@@ -101,28 +122,59 @@ def get_logits_processing(config: GenerationConfig):
 
     return logits_processors
 
-def apply_logits_processing(logits, logits_processing_list, **kwargs):
+def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
     for process in logits_processing_list:
         func_args = inspect.signature(process.__call__).parameters
-        if not all(arg in kwargs for arg in list(func_args.keys())[1:]):
+        if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
             raise ValueError(
                 f"Make sure that all the required parameters: {list(func_args.keys())} for "
                 f"{process.__class__} are passed to the logits processor."
             )
-        logits = process(logits, **kwargs)
+        if "input_ids" in func_args:
+            logits = process(input_ids, logits)
+        else:
+            logits = process(logits, **kwargs)
     return logits
 
-def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token):
+def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
+    # stop_strings must be a list of lists: List[List[], List[]]
+
+    device = input_ids.device
+    batch_size, seq_len = input_ids.shape
+    finished = torch.zeros(batch_size, dtype = torch.bool, device = device)
+
+    for b in range(batch_size):
+        row = input_ids[b]
+        # check each stop token sequence
+        for stop_ids in stop_strings:
+            n = len(stop_ids)
+            if n == 0 or n > seq_len:
+                continue
+            # compare tail of the generated ids to the stop sequence
+            if torch.all(row[-n:] == torch.tensor(stop_ids, device = device, dtype = row.dtype)):
+                finished[b] = True
+                break
+
+    return finished
+
+def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: tuple = None):
+
+    device = input_ids.device
 
     if not isinstance(eos_token, torch.Tensor):
-        eos_token = torch.tensor(eos_token, device=input_ids.device)
+        eos_token = torch.tensor(eos_token, device = device)
 
     max_len_done = input_ids.shape[1] >= max_length
 
     eos_done = torch.isin(input_ids[:, -1], eos_token)
 
-    # finished either by lenght or eos
-    finished_mask = max_len_done | eos_done
+    if stop_strings is not None:
+        stop_done = check_stopping_strings(input_ids, stop_strings)
+    else:
+        stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
+
+    # finished either by lenght or eos or stop strings
+    finished_mask = max_len_done | eos_done | stop_done
 
     unfinished_mask = ~finished_mask
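A small usage sketch of the new tail-matching stop check, with toy token ids that are not from the commit: a row is marked finished once its last tokens equal one of the stop sequences, and check_stopping_criteria ORs this with the existing length and EOS checks:

import torch
from comfy.autoregressive_sampling import check_stopping_strings, check_stopping_criteria

stop_strings = [[50, 51], [99]]                           # hypothetical stop sequences
input_ids = torch.tensor([
    [1, 2, 50, 51],                                       # ends with [50, 51] -> finished
    [1, 2,  3, 99],                                       # ends with [99]     -> finished
    [1, 2,  3,  4],                                       # no stop sequence   -> still running
])
print(check_stopping_strings(input_ids, stop_strings))    # tensor([ True,  True, False])

finished, unfinished = check_stopping_criteria(
    input_ids, max_length=32, eos_token=[0], stop_strings=stop_strings
)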

comfy/ldm/higgsv2/model.py

Lines changed: 22 additions & 10 deletions
@@ -500,6 +500,8 @@ def __init__(self, device = None, dtype = None, operations = None, **kwargs):
             torch.ones(kwargs["audio_num_codebooks"]) / kwargs["audio_num_codebooks"]
         )
 
+        self.stop_strings = [[128009], [128001]]
+
     def _sample_audio_tokens(
         self,
         audio_logits: torch.Tensor,
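The hard-coded self.stop_strings = [[128009], [128001]] are single-token stop sequences; in the Llama-3 tokenizer family these ids appear to correspond to <|eot_id|> and <|end_of_text|>, though the commit does not say so. A hedged way to verify, assuming a Llama-3-family tokenizer is available through Hugging Face transformers (the checkpoint name is only an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")   # example checkpoint
print(tok.convert_ids_to_tokens([128009, 128001]))        # expected: ['<|eot_id|>', '<|end_of_text|>']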
@@ -520,7 +522,7 @@ def _sample_audio_tokens(
         audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None)
 
         next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device)
-        next_audio_token_scores = apply_logits_processing(next_audio_token_logits, logits_processing_list)
+        next_audio_token_scores = apply_logits_processing(None, next_audio_token_logits, logits_processing_list)
 
         if do_sample:
             probs = nn.functional.softmax(next_audio_token_scores, dim = -1)
@@ -588,14 +590,17 @@ def _sample_text_tokens(
         logits_processing_list,
         device: torch.device,
         generation_mode: GenerationMode,
+        torch_generator,
+        is_using_cuda_graphs,
+        do_sample = False,
     ) -> torch.Tensor:
         """Sample text tokens from the logits"""
 
         next_token_logits = logits.clone()[:, -1, :].float()
         next_token_logits = next_token_logits.to(input_ids.device)
 
         # pre-process distribution
-        next_token_scores = apply_logits_processing(next_token_logits, logits_processing_list)
+        next_token_scores = apply_logits_processing(input_ids, next_token_logits, logits_processing_list)
 
         if generation_mode == GenerationMode.AUDIO_INIT:
             # See the audio bos token, we should start generating audio tokens
@@ -612,7 +617,17 @@ def _sample_text_tokens(
                 device=device,
             )
         else:
-            next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim = -1)
+                # same as for audio
+                if not is_using_cuda_graphs:
+                    next_tokens = torch.multinomial(probs, num_samples = 1, generator=torch_generator).squeeze(1)
+                else:
+                    next_tokens = categorical_sample(probs, generator = torch_generator)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
         next_audio_tokens = None
 
         return next_tokens, next_audio_tokens
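The new do_sample branch mirrors the audio path: softmax the processed scores, then draw with torch.multinomial, or fall back to categorical_sample when CUDA graphs are in use (categorical_sample is defined elsewhere in the codebase and is not shown in this diff). As a hedged sketch of what a graph-friendly categorical draw can look like, here is a Gumbel-max variant; this is an assumption for illustration, not the repo's actual implementation:

import torch

def gumbel_categorical_sample(probs: torch.Tensor, generator=None) -> torch.Tensor:
    # Perturb log-probabilities with Gumbel noise and take the argmax.
    # Unlike torch.multinomial this is a fixed-shape computation, which is
    # one common way to stay CUDA-graph friendly (sketch only).
    eps = torch.finfo(probs.dtype).tiny
    uniform = torch.rand(probs.shape, device=probs.device, generator=generator)
    gumbel = -torch.log(-torch.log(uniform.clamp_min(eps)).clamp_min(eps))
    return torch.argmax(torch.log(probs.clamp_min(eps)) + gumbel, dim=-1)

probs = torch.softmax(torch.randn(2, 10), dim=-1)
next_tokens = gumbel_categorical_sample(probs)            # shape [2], like multinomial(...).squeeze(1)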
@@ -1093,12 +1108,6 @@ def _sample(
             del model_inputs["audio_out_ids_start"]
 
             if generation_config.use_cache:
-                if "audio_features" in model_inputs and model_inputs["audio_features"] is not None:
-                    model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...]
-                    model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][
-                        :0, ...
-                    ]
-
                 if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None:
                     model_inputs["audio_in_ids"] = None
                     model_inputs["audio_in_ids_start"] = None
@@ -1159,6 +1168,9 @@ def _sample(
                 logits_processing_list=logits_processing_list,
                 device=input_ids.device,
                 generation_mode=generation_mode,
+                torch_generator = torch_generator,
+                do_sample = do_sample,
+                is_using_cuda_graphs = is_using_cuda_graphs
             )
 
             if next_audio_tokens is not None:
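The extra torch_generator, do_sample, and is_using_cuda_graphs arguments thread the sampler state through to _sample_text_tokens. A tiny sketch, with a hypothetical seed, of why an explicit torch.Generator is passed around: the same seed reproduces the same draw.

import torch

probs = torch.softmax(torch.randn(1, 8), dim=-1)          # toy distribution over 8 tokens

gen_a = torch.Generator(device="cpu").manual_seed(0)
gen_b = torch.Generator(device="cpu").manual_seed(0)
tok_a = torch.multinomial(probs, num_samples=1, generator=gen_a).squeeze(1)
tok_b = torch.multinomial(probs, num_samples=1, generator=gen_b).squeeze(1)
assert torch.equal(tok_a, tok_b)                          # same seed -> same sampled token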
@@ -1199,7 +1211,7 @@ def _sample(
                 input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
 
             input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1)
-            finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor)
+            finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor, stop_strings = self.stop_strings)
             this_peer_finished = finished.all()
             cur_len += 1
 
comfy/text_encoders/higgsv2.py

Lines changed: 2 additions & 1 deletion
@@ -71,7 +71,8 @@ def decode_tokens(self, audio_tokens):
             wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
             outputs.append(wv_numpy)
 
-        return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
+        # currently only supports one batch size
+        return (None, {"waveform": torch.cat(outputs, dim = 0).unsqueeze(0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
 
     def load_state_dict(self, sd, strict = False):
         return self.audio_tokenizer.load_state_dict(sd, strict = strict)
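The decode change swaps torch.stack for torch.cat: stacking required every decoded chunk to have the same length and produced an [N, 1, T] batch, while concatenating joins the chunks along time and returns a single [1, 1, T] waveform, matching the new "one batch size" comment. A toy shape check with hypothetical chunk lengths:

import torch

outputs = [torch.zeros(16000), torch.zeros(8000)]         # decoded chunks of different lengths

# torch.stack(outputs, dim=0) would raise here because the chunks differ in length;
# torch.cat joins them along time into one [batch, channels, samples] waveform
waveform = torch.cat(outputs, dim=0).unsqueeze(0).unsqueeze(1)
print(waveform.shape)                                      # torch.Size([1, 1, 24000])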

comfy_extras/nodes_audio.py

Lines changed: 41 additions & 5 deletions
@@ -22,7 +22,6 @@
     prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
 )
 
-AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
 
 MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
 If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
@@ -175,6 +174,8 @@ def convert_to_ml_format(self, clip, text, audio=None):
         current_role = None
         collecting_system = False
         system_buffer = []
+        collecting_instruction = False
+        instruction_buffer = []
 
         for line in lines:
             line = line.strip()
@@ -197,6 +198,26 @@ def convert_to_ml_format(self, clip, text, audio=None):
                 collecting_system = False
                 continue
 
+            # generation instruction start
+            if "<|generation_instruction_start|>" in line:
+                collecting_instruction = True
+                instruction_buffer = []
+                continue
+
+            if collecting_instruction:
+                if "<|generation_instruction_end|>" in line:
+                    instruction_text = "\n".join(instruction_buffer)
+                    # include both start and end tokens
+                    messages.append(Message(
+                        role="user",
+                        content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
+                    ))
+                    instruction_buffer = []
+                    collecting_instruction = False
+                else:
+                    instruction_buffer.append(line)
+                continue
+
             # speaker lines SPEAKER-0: text
             match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
             if match:
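For reference, a standalone sketch of what the new instruction-collecting branch does; the real code runs inside convert_to_ml_format and appends a Message instead of returning strings, and the prompt below is only an assumed example. Lines between <|generation_instruction_start|> and <|generation_instruction_end|> are gathered into one user message that keeps both marker tokens:

def collect_instructions(lines):
    # minimal re-implementation of the collecting_instruction logic for illustration
    collecting, buffer, out = False, [], []
    for line in (l.strip() for l in lines):
        if "<|generation_instruction_start|>" in line:
            collecting, buffer = True, []
            continue
        if collecting:
            if "<|generation_instruction_end|>" in line:
                out.append("<|generation_instruction_start|>\n"
                           + "\n".join(buffer)
                           + "\n<|generation_instruction_end|>")
                collecting = False
            else:
                buffer.append(line)
    return out

print(collect_instructions([
    "<|generation_instruction_start|>",
    "Speak slowly, in a calm tone.",
    "<|generation_instruction_end|>",
    "SPEAKER-0: Hello there.",
]))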
@@ -221,14 +242,29 @@ def convert_to_ml_format(self, clip, text, audio=None):
             lines = all_text.splitlines()
             messages = [messages[0]] if messages[0].role == "system" else []
             current_role = None
+
             for line in lines:
-                match = re.match(r'\[SPEAKER\d+\]', line)
+                line = line.strip()
+                if not line:
+                    continue
+
+                match = re.match(r'(\[SPEAKER\d+\])\s*(.*)', line)
                 if match:
-                    current_role = match.group(0)
-                    messages.append(Message(role="user", content=line.strip()))
+                    current_role = match.group(1)
+                    content = match.group(2).strip() # only take the text after the tag
+                    messages.append(Message(role="user", content=f"{current_role} {content}" if content else current_role))
                 else:
                     if current_role and messages:
-                        messages[-1].content += "\n" + line.strip()
+                        messages[-1].content += "\n" + line
+
+        # dedepulicate the messages
+        for idx, m in enumerate(messages):
+            double_eot = "<|eot_id|><|eot_id|>"
+            if double_eot in m.content:
+                cut_index = m.content.index(double_eot)
+                messages[idx].content = m.content[:cut_index + (len(double_eot) // 2)]
+                break
+
         if audio is not None:
             # for audio cloning, the first message is a transcript, second is the audio,
             # third is the request of what the model should say
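The speaker-line change is easiest to see on the regex itself: the old pattern only matched the [SPEAKER*] tag and kept the whole stripped line as content, while the new pattern captures the tag and the trailing text separately so the message can be rebuilt with normalized spacing. A quick comparison on a hypothetical line:

import re

line = "[SPEAKER0]    Hello there, how are you?"

old = re.match(r'\[SPEAKER\d+\]', line)
print(old.group(0))                                        # '[SPEAKER0]' (tag only; old code kept line.strip() as the content)

new = re.match(r'(\[SPEAKER\d+\])\s*(.*)', line)
print(f"{new.group(1)} {new.group(2).strip()}")            # '[SPEAKER0] Hello there, how are you?'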

nodes.py

Lines changed: 0 additions & 1 deletion
@@ -2065,7 +2065,6 @@ def expand_image(self, image, left, top, right, bottom, feathering):
     "KSampler": "KSampler",
     "KSamplerAdvanced": "KSampler (Advanced)",
     "AutoRegressiveGeneration": "Autoregressive Generation",
-    ""
     # Loaders
     "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
     "CheckpointLoaderSimple": "Load Checkpoint",
