Skip to content

Commit 2ac8999

Browse files
committed
final
1 parent 6412422 commit 2ac8999

File tree

3 files changed

+3
-26
lines changed

3 files changed

+3
-26
lines changed

comfy/autoregressive_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
255255
self.dtype = model.dtype
256256

257257
self.model = model
258+
self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
259+
self.model.generation_config.cache_implementation = self.model.cache_implementation
258260

259261
text_config = self.model.cache_config
260262
self.cache_config = CacheConfig(
@@ -331,8 +333,6 @@ def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length:
331333
do_sample = do_sample,
332334
temperature = temperature)
333335

334-
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
335-
336336
generation_config, model_kwargs = self._prepare_generation_config(
337337
generation_config, **kwargs
338338
)

comfy/ldm/higgsv2/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,14 +428,13 @@ def __init__(self, device = None, dtype = None, operations = None, **kwargs):
428428
self.cache_config = kwargs["text_config"]
429429
self.hidden_dim = kwargs["text_config"]["hidden_size"]
430430
self.max_seq_len = kwargs["text_config"]["max_position_embeddings"]
431+
self.cache_implementation = "static"
431432
self.use_kv_buckets = kwargs.get("use_kv_buckets", False)
432433

433434
self.dtype = dtype
434435
self.device = device
435436
self.config = kwargs
436437

437-
self.generation_config = GenerationConfig.from_model_config(kwargs)
438-
self.generation_config.cache_implementation = self.cache_implementation = "static"
439438

440439
self.audio_out_bos_token_id = 128013
441440
self.audio_eos_token_id = 128012

comfy_extras/nodes_audio.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,6 @@ def convert_to_ml_format(self, clip, text, audio=None):
174174
current_role = None
175175
collecting_system = False
176176
system_buffer = []
177-
collecting_instruction = False
178-
instruction_buffer = []
179177

180178
for line in lines:
181179
line = line.strip()
@@ -198,26 +196,6 @@ def convert_to_ml_format(self, clip, text, audio=None):
198196
collecting_system = False
199197
continue
200198

201-
# generation instruction start
202-
if "<|generation_instruction_start|>" in line:
203-
collecting_instruction = True
204-
instruction_buffer = []
205-
continue
206-
207-
if collecting_instruction:
208-
if "<|generation_instruction_end|>" in line:
209-
instruction_text = "\n".join(instruction_buffer)
210-
# include both start and end tokens
211-
messages.append(Message(
212-
role="user",
213-
content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
214-
))
215-
instruction_buffer = []
216-
collecting_instruction = False
217-
else:
218-
instruction_buffer.append(line)
219-
continue
220-
221199
# speaker lines SPEAKER-0: text
222200
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
223201
if match:

0 commit comments

Comments
 (0)