Commit 57c15f9
committed: styling fixes
1 parent: 6e9335d

File tree: 12 files changed, +111 −113 lines changed. Only four of the changed files are reproduced below. Because this commit is a whitespace/styling cleanup, many removed (-) and added (+) line pairs in the diffs appear identical once indentation and trailing whitespace are normalized.


comfy/autoregressive_sampling.py

Lines changed: 30 additions & 30 deletions
@@ -19,27 +19,27 @@ def estimate_autoregressive_vram(
    max_seq_len: int,
    batch_size: int = 1,
    dtype = torch.float16,
-    intermediate_factor: float = 4.0,
+    intermediate_factor: float = 4.0,
    device = torch.device('cuda')
) -> bool:
-
+
    dtype_size = torch.finfo(dtype).bits // 8
    kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size

-    # we only calculate hidden states in cuda graphs, so we don't care about the output logits
-    input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
-
+    # we only calculate hidden states in cuda graphs, so we don't care about the output logits
+    input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
+
    # rough calculation for activation sizes
    intermediate_bytes = intermediate_factor * output_bytes
-
+
    total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
-
+
    # get vram info
    free_vram = get_free_memory(device)
    minimum_vram = minimum_inference_memory()
-
+
    enough_vram = free_vram - minimum_vram >= total_estimated
-
+
    return enough_vram

class TopKLogits:
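
As a rough sense of the arithmetic that estimate_autoregressive_vram performs in the hunk above, here is the same formula evaluated with hypothetical model dimensions (all numbers below are made up for illustration, not taken from this commit):

import torch

# Hypothetical dimensions, chosen only to illustrate the formula in the hunk above.
num_layers, hidden_dim, max_seq_len, batch_size = 32, 4096, 8192, 1
dtype_size = torch.finfo(torch.float16).bits // 8  # 2 bytes per element

# K and V for every layer: num_layers * seq * hidden * 2 (K and V) * batch * bytes
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size

# hidden-state input/output buffers plus a rough activation multiplier
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
intermediate_bytes = 4.0 * output_bytes

total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
print(f"{total_estimated / 2**30:.2f} GiB")  # ~4.38 GiB with these numbers

The estimate is then compared against free VRAM minus the minimum inference reserve, as shown in the hunk.
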
@@ -64,7 +64,7 @@ def __init__(self, temperature: float):
    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores / self.temperature
        return scores_processed
-
+
class TopPLogitsWarper:
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
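
For orientation, the temperature and top-p warpers touched in this hunk implement the standard logit-warping technique. A generic, self-contained sketch of that technique (not this repository's exact implementation):

import torch

def warp_logits(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 0.9,
                filter_value: float = -float("inf")) -> torch.Tensor:
    # Temperature: divide the logits, as in the temperature warper's __call__ above.
    scores = logits / temperature

    # Nucleus (top-p) filtering: keep the smallest prefix of the sorted
    # distribution whose cumulative probability reaches top_p.
    sorted_scores, sorted_idx = torch.sort(scores, descending=True, dim=-1)
    cumulative_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)

    sorted_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is kept,
    # and never remove the single most probable token.
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False

    remove_mask = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
    return scores.masked_fill(remove_mask, filter_value)

# usage: probs = warp_logits(next_token_logits, temperature=0.8, top_p=0.95).softmax(-1)
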
@@ -175,7 +175,7 @@ def from_model_config(cls, config_dict: dict, **kwargs) -> GenerationConfig:

        config_dict = {key: value for key, value in config_dict.items() if value is not None}
        valid_fields = {f.name for f in fields(cls)}
-
+
        filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}

        generation_config = cls(**filtered_args)
@@ -216,7 +216,7 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
        self.model.cache_config = self.cache_config

        self.kv_caches = {
-
+
            length: StaticCache(
                config=self.cache_config,
                max_batch_size = self.cache_config.max_batch,
@@ -234,8 +234,8 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):

        # cuda graphs only help if input shapes are constant
        if (
-            device == "cuda"
-            and hasattr(model, "capture_model")
+            device == "cuda"
+            and hasattr(model, "capture_model")
            and self.model.cache_implementation == "static"
            and self.model.use_kv_buckets
            and enough_vram
@@ -247,7 +247,7 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
    @torch.inference_mode()
    def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0,
                 top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed = None, **kwargs):
-
+
        if seed is not None:
            torch_generator = torch.Generator(device = input_ids.device).manual_seed(seed)
        else:
@@ -335,7 +335,7 @@ def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length:
        # TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
        if not hasattr(self.model, "_sample"):
            raise ValueError("Model doesn't support AutoRegressive Generation!")
-
+
        self._prepare_kv_caches()

        result = self.model._sample(
@@ -347,7 +347,7 @@ def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length:
        )

        return result
-
+
    def _prepare_kv_caches(self):
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()
@@ -357,13 +357,13 @@ def get_generation_mode(self, config: GenerationConfig):
            return GenerationSampling.BEAM_SAMPLING
        else:
            return GenerationSampling.GREEDY_SEARCH
-
+
    def _prepare_generated_length(
        self,
        generation_config: GenerationConfig,
        input_ids_length,
    ):
-
+
        """ max_length = user_input_id_tokens + generation_max_length """

        if generation_config.max_new_length is not None:
@@ -374,11 +374,11 @@ def _prepare_generated_length(
            generation_config.min_length = generation_config.min_new_length + input_ids_length

        return generation_config
-
+
    def _get_cache(
        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
    ) -> Cache:
-
+
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"

        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
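
The docstring in the _prepare_generated_length hunk above states the length rule directly; as a minimal sketch (helper name hypothetical), it amounts to:

def resolve_lengths(input_ids_length: int, max_new_length: int, min_new_length: int = 0):
    # max_length = user_input_id_tokens + generation_max_length, per the docstring above
    max_length = input_ids_length + max_new_length
    min_length = input_ids_length + min_new_length
    return max_length, min_length

# e.g. a 40-token prompt with max_new_length=1024 may generate up to absolute position 1064
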
@@ -412,7 +412,7 @@ def _get_cache(

        return self.model._cache

-
+
    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
@@ -466,7 +466,7 @@ def _prepare_generation_config(self, generation_config: GenerationConfig, **kwar
        model_kwargs = generation_config.update(**kwargs)

        return generation_config, model_kwargs
-
+
    def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
        """Performs validation related to the resulting generated length"""

@@ -498,7 +498,7 @@ def _validate_generated_length(self, generation_config: GenerationConfig, input_
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
-
+
    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
@@ -526,13 +526,13 @@ def _expand_dict_for_generation(dict_to_expand):
        model_kwargs = _expand_dict_for_generation(model_kwargs)

        return input_ids, model_kwargs
-
+
    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
-
+
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
@@ -564,7 +564,7 @@ def _prepare_attention_mask_for_generation(
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:
-
+
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor

@@ -593,12 +593,12 @@ def _prepare_attention_mask_for_generation(
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask
-
+
def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample = False, seed=None, **kwargs):
    # to work with BaseModel
    if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
        model = patcher.model.diffusion_model
-
+
    if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
        if model.device != patcher.load_device:
            model = model.to(patcher.load_device, dtype=model.dtype)
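
The expression near the top of the _prepare_attention_mask_for_generation hunk selects between a padding-derived mask and an all-ones default. A simplified sketch of that pattern (it omits the usual pad-vs-eos check and is not this file's exact implementation):

import torch

def infer_attention_mask(input_ids: torch.LongTensor, pad_token_id=None) -> torch.LongTensor:
    # Default: attend to every position.
    default_attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    if pad_token_id is None:
        return default_attention_mask

    # Only trust the padding-derived mask when the pad token actually occurs in the prompt;
    # otherwise padding cannot be distinguished from content.
    attention_mask_from_padding = input_ids.ne(pad_token_id).long()
    can_infer_attention_mask = (input_ids == pad_token_id).any().long()
    return (attention_mask_from_padding * can_infer_attention_mask
            + default_attention_mask * (1 - can_infer_attention_mask))
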
@@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0,
        kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
    else:
        main_input_ids = input_ids
-
+
    device = node._cached_autoregressive_sampler.device

    main_input_ids = main_input_ids.to(device)

comfy/ldm/higgsv2/cuda_graph_runner.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def graph(self):

    def capture(self, *args, **kwargs):
        assert self._graph is None
-
+
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(*args, **kwargs)
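
The warm-up loop above runs the model a few times before capture. The usual PyTorch pattern behind such a runner looks roughly like the sketch below (a generic torch.cuda.CUDAGraph example with placeholder names and an assumed warm-up count, not this file's exact code):

import torch

_NUM_WARMUP_ITERS = 2  # assumed value for illustration

def capture_forward(model, static_input: torch.Tensor):
    # Warm up on a side stream so lazy allocations and autotuning happen before capture.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(_NUM_WARMUP_ITERS):
            model(static_input)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture a single forward pass; replaying the graph reruns it on whatever
    # data currently sits in static_input (shapes must stay constant).
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)
    return graph, static_output

# replay: copy new data into static_input, call graph.replay(), then read static_output
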

comfy/ldm/higgsv2/loudness.py

Lines changed: 4 additions & 4 deletions
@@ -125,7 +125,7 @@ def apply_filter(self, data):
    @property
    def b_and_a(self):
        return self.generate_coefficients()
-
+
class Meter(torch.nn.Module):

    def __init__(
@@ -227,7 +227,7 @@ def _unfold(self, input_data):
        return unfolded

    def integrated_loudness(self, data: torch.Tensor):
-
+
        if not torch.is_tensor(data):
            data = torch.from_numpy(data).float()
        else:
@@ -291,10 +291,10 @@ def integrated_loudness(self, data: torch.Tensor):

def loudness(
    audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
-):
+):
    MIN_LOUDNESS = -70
    device = audio_data.device
-
+
    original_length = audio_data.shape[-1]
    signal_duration = original_length / sample_rate
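
For context on the measured-vs-target loudness values handled above: loudness normalization is ordinarily a single dB gain derived from the integrated loudness. A generic sketch (not this file's implementation; the floor mirrors the MIN_LOUDNESS = -70 constant shown in the hunk):

import torch

def normalize_loudness(audio: torch.Tensor, measured_lufs: float,
                       target_lufs: float = -24.0, min_loudness: float = -70.0) -> torch.Tensor:
    # Below the gating floor there is effectively no signal to normalize.
    if measured_lufs <= min_loudness:
        return audio
    # Loudness is logarithmic, so the correction is a plain dB gain.
    gain_db = target_lufs - measured_lufs
    return audio * (10.0 ** (gain_db / 20.0))

# e.g. a clip measured at -30 LUFS with a -24 LUFS target gets +6 dB, roughly a 2x amplitude gain
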

comfy/ldm/higgsv2/model.py

Lines changed: 17 additions & 17 deletions
@@ -19,8 +19,8 @@
from typing import Optional, Tuple, Union, List

class GenerationMode(Enum):
-    TEXT = 0
-    AUDIO_INIT = 1
+    TEXT = 0
+    AUDIO_INIT = 1
    AUDIO_IN_PROGRESS = 2

def _ignore_causal_mask_sdpa(
@@ -413,7 +413,7 @@ class HiggsAudioModel(nn.Module):

    def __init__(self, device = None, dtype = None, operations = None, **kwargs):
        super().__init__()
-
+
        self.padding_idx = kwargs["pad_token_id"]
        self.audio_in_token_idx = kwargs["audio_in_token_idx"]
        self.audio_out_token_idx = kwargs["audio_out_token_idx"]
@@ -439,7 +439,7 @@ def __init__(self, device = None, dtype = None, operations = None, **kwargs):

        self.audio_out_bos_token_id = 128013
        self.audio_eos_token_id = 128012
-
+
        text_config = kwargs["text_config"]
        llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"],
                                    num_key_value_heads = text_config["num_key_value_heads"],
@@ -616,7 +616,7 @@ def _sample_text_tokens(
            next_audio_tokens = None

        return next_tokens, next_audio_tokens
-
+
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
@@ -677,7 +677,7 @@ def _update_causal_mask(
            causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True))

        return causal_mask
-
+
    def _embed_audio_ids(self, audio_ids):
        codebook_shift = (
            torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size
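
The codebook_shift at the end of this hunk offsets each codebook's ids into its own slice of a shared embedding table. A minimal illustration of that indexing (dimensions hypothetical; whether the per-codebook embeddings are summed afterwards is an assumption, not shown in the diff):

import torch

num_codebooks, codebook_size, embed_dim = 8, 1024, 16
audio_embed = torch.nn.Embedding(num_codebooks * codebook_size, embed_dim)

# audio_ids: (num_codebooks, seq_len), each row holding ids in [0, codebook_size)
audio_ids = torch.randint(0, codebook_size, (num_codebooks, 12))

# Shift row k by k * codebook_size so every codebook addresses a disjoint block
# of the shared table, then combine the per-codebook embeddings.
codebook_shift = torch.arange(num_codebooks).unsqueeze(-1) * codebook_size
embeddings = audio_embed(audio_ids + codebook_shift).sum(dim=0)  # (seq_len, embed_dim)
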
@@ -712,7 +712,7 @@ def _prepare_all_static_kv_cache_masks(self, hidden_states, attention_mask, audi
        )
        audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype)
        return fast_forward_attention_mask, audio_attention_mask
-
+
    def _forward_core(
        self,
        hidden_states: torch.Tensor,
@@ -728,7 +728,7 @@ def _forward_core(
        is_using_cuda_graph: Optional[bool] = False,
    ):

-        position_id_offset = cache_position[0] if use_cache else 0
+        position_id_offset = cache_position[0] if use_cache else 0
        position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset)

        for decoder_layer in self.layers:
@@ -927,7 +927,7 @@ def forward(
        )

        return ret
-
+
    def _update_model_kwargs_for_generation(
        self,
        outputs,
@@ -956,13 +956,13 @@ def _update_model_kwargs_for_generation(
        )

        return model_kwargs
-
+
    def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
        from_cache_size = from_cache.get_max_cache_shape()
        assert to_cache.get_max_cache_shape() >= from_cache_size, (
            f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}."
        )
-
+
        n_layers = self.num_hidden_layers

        for i in range(n_layers):
@@ -977,7 +977,7 @@ def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
                     self.cache_config.head_dim),
                    device=self.device, dtype=self.dtype
                )
-
+
            if getattr(to_layer, "values", None) is None:
                to_layer.values = torch.zeros(
                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
@@ -1011,7 +1011,7 @@ def _prepare_kv_cache(
                f"The current sequence length {current_sequence_length} is larger than "
                f"all past key values buckets {past_key_values_buckets.keys()}."
            )
-
+
    def _sample(
        self,
        input_ids: torch.LongTensor,
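
The error message in the _prepare_kv_cache hunk implies the selection rule: use the smallest pre-allocated KV-cache bucket that still fits the current sequence, and fail when none does. A sketch of that rule (function name hypothetical; the bucket lengths mirror the kv_cache_lengths default [1024, 4096, 8192] from autoregressive_sampling.py):

from collections import OrderedDict

def pick_kv_bucket(current_sequence_length: int, past_key_values_buckets: OrderedDict):
    # Buckets are keyed by maximum cache length, e.g. {1024: cache, 4096: cache, 8192: cache}.
    for bucket_len in sorted(past_key_values_buckets):
        if current_sequence_length <= bucket_len:
            return past_key_values_buckets[bucket_len]
    raise ValueError(
        f"The current sequence length {current_sequence_length} is larger than "
        f"all past key values buckets {past_key_values_buckets.keys()}."
    )
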
@@ -1020,7 +1020,7 @@ def _sample(
        past_key_values_buckets: Optional[OrderedDict[int, Cache]],
        **model_kwargs,
    ):
-
+
        # code supports only non-mixed batchs

        audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)
@@ -1069,7 +1069,7 @@ def _sample(

        while not this_peer_finished:
            eos_token_tensor = torch.tensor([self.config["text_config"]["eos_token_id"]], device=input_ids.device)
-
+
            if input_ids[0][-1] == audio_out_bos_token_id:
                generation_mode = GenerationMode.AUDIO_INIT
            elif input_ids[0][-1] == self.audio_out_token_idx:
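
Together with the GenerationMode enum from this file's first hunk, the branch shown here decides what to sample at each step from the last emitted token. Roughly (the TEXT fallback is an assumption; only the two audio branches are visible in the diff):

def current_generation_mode(last_token_id, audio_out_bos_token_id, audio_out_token_idx):
    if last_token_id == audio_out_bos_token_id:
        return GenerationMode.AUDIO_INIT         # an audio segment is about to start
    if last_token_id == audio_out_token_idx:
        return GenerationMode.AUDIO_IN_PROGRESS  # keep emitting audio codebook tokens
    return GenerationMode.TEXT                   # otherwise sample ordinary text tokens
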
@@ -1211,7 +1211,7 @@ def _sample(
            pbar.update(pbar.total - pbar.current)

        return audio_sequences
-
+
    @torch.inference_mode()
    def generate(
        self,
@@ -1222,7 +1222,7 @@ def generate(
        generation_functions = None,
        **kwargs,
    ):
-
+
        if generation_config is None:
            generation_config = GenerationConfig()