Commit 01f4512

In-progress commit making flipflop async weight streaming native; also gave the "loaded partially"/"loaded completely" log messages labels, since having to memorize their meaning for dev work is annoying
1 parent d0bd221 commit 01f4512

4 files changed (+145 −119 lines)

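Before the per-file diffs, a minimal sketch of the flip/flop double-buffering idea this commit makes native (the helper below is illustrative only; its names and structure are assumptions, not this repo's API): two GPU-resident block buffers alternate between computing and being refilled, so while one buffer computes block i on the default stream, a side copy stream prefetches block i+2's weights into the buffer that just finished, with CUDA events ordering the copies against compute.

import copy
import torch

def flipflop_forward(blocks, x, device="cuda"):
    # Illustrative double-buffering: bufs[i % 2] computes block i on the default
    # stream while the side stream refills the idle buffer with later weights.
    compute = torch.cuda.default_stream(device)
    copier = torch.cuda.Stream(device)
    bufs = [copy.deepcopy(blocks[0]).to(device), copy.deepcopy(blocks[1]).to(device)]
    ready = [torch.cuda.Event(), torch.cuda.Event()]
    for i in range(len(blocks)):
        buf = bufs[i % 2]
        if i >= 2:
            compute.wait_event(ready[i % 2])  # prefetch into this buffer must be done
        x = buf(x)
        nxt = i + 2
        if nxt < len(blocks):
            with torch.cuda.stream(copier):
                copier.wait_stream(compute)  # don't clobber weights still in use
                dst = buf.state_dict()
                for k, v in blocks[nxt].state_dict().items():
                    dst[k].copy_(v, non_blocking=True)  # truly async only if src is pinned
                ready[i % 2].record(copier)
    return x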

comfy/ldm/flipflop_transformer.py

Lines changed: 84 additions & 70 deletions
@@ -3,54 +3,22 @@
 import torch.cuda as cuda
 import copy
 from typing import List, Tuple
-from dataclasses import dataclass
 
 import comfy.model_management
 
-FLIPFLOP_REGISTRY = {}
-
-def register(name):
-    def decorator(cls):
-        FLIPFLOP_REGISTRY[name] = cls
-        return cls
-    return decorator
-
-
-@dataclass
-class FlipFlopConfig:
-    block_name: str
-    block_wrap_fn: callable
-    out_names: Tuple[str]
-    overwrite_forward: str
-    pinned_staging: bool = False
-    inference_device: str = "cuda"
-    offloading_device: str = "cpu"
-
-
-def patch_model_from_config(model, config: FlipFlopConfig):
-    block_list = getattr(model, config.block_name)
-    flip_flop_transformer = FlipFlopTransformer(block_list,
-                                                block_wrap_fn=config.block_wrap_fn,
-                                                out_names=config.out_names,
-                                                offloading_device=config.offloading_device,
-                                                inference_device=config.inference_device,
-                                                pinned_staging=config.pinned_staging)
-    delattr(model, config.block_name)
-    setattr(model, config.block_name, flip_flop_transformer)
-    setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
-
 
 class FlipFlopContext:
     def __init__(self, holder: FlipFlopHolder):
         self.holder = holder
         self.reset()
 
     def reset(self):
-        self.num_blocks = len(self.holder.transformer_blocks)
+        self.num_blocks = len(self.holder.blocks)
         self.first_flip = True
         self.first_flop = True
         self.last_flip = False
         self.last_flop = False
+        # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly
 
     def __enter__(self):
         self.reset()
@@ -71,9 +39,9 @@ def do_flip(self, func, i: int, _, *args, **kwargs):
             next_flop_i = next_flop_i - self.num_blocks
             self.last_flip = True
         if not self.first_flip:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
         if self.last_flip:
-            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
+            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
         self.first_flip = False
         return out
 
@@ -89,9 +57,9 @@ def do_flop(self, func, i: int, _, *args, **kwargs):
         if next_flip_i >= self.num_blocks:
             next_flip_i = next_flip_i - self.num_blocks
             self.last_flop = True
-        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
+        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
         if self.last_flop:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
         self.first_flop = False
         return out
 
@@ -106,19 +74,20 @@ def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):
 
 
 class FlipFlopHolder:
-    def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.load_device = torch.device(inference_device)
-        self.offload_device = torch.device(offloading_device)
-        self.transformer_blocks = transformer_blocks
+    def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"):
+        self.load_device = torch.device(load_device)
+        self.offload_device = torch.device(offload_device)
+        self.blocks = blocks
+        self.flip_amount = flip_amount
 
         self.block_module_size = 0
-        if len(self.transformer_blocks) > 0:
-            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
+        if len(self.blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.blocks[0])
 
         self.flip: torch.nn.Module = None
         self.flop: torch.nn.Module = None
         # TODO: make initialization happen in model management code/model patcher, not here
-        self.initialize_flipflop_blocks(self.load_device)
+        self.init_flipflop_blocks(self.load_device)
 
         self.compute_stream = cuda.default_stream(self.load_device)
         self.cpy_stream = cuda.Stream(self.load_device)
@@ -142,10 +111,57 @@ def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy
     def context(self):
         return FlipFlopContext(self)
 
-    def initialize_flipflop_blocks(self, load_device: torch.device):
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+    def init_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
+
+    def clean_flipflop_blocks(self):
+        del self.flip
+        del self.flop
+        self.flip = None
+        self.flop = None
+
+
+class FlipFlopModule(torch.nn.Module):
+    def __init__(self, block_types: tuple[str, ...]):
+        super().__init__()
+        self.block_types = block_types
+        self.flipflop: dict[str, FlipFlopHolder] = {}
+
+    def setup_flipflop_holders(self, block_percentage: float):
+        for block_type in self.block_types:
+            if block_type in self.flipflop:
+                continue
+            num_blocks = int(len(getattr(self, block_type)) * (1.0 - block_percentage))
+            self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[num_blocks:], num_blocks)
+
+    def clean_flipflop_holders(self):
+        for block_type in list(self.flipflop.keys()):
+            self.flipflop[block_type].clean_flipflop_blocks()
+            del self.flipflop[block_type]
+
+    def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
+        if block_type not in self.block_types:
+            raise ValueError(f"Block type {block_type} not found in {self.block_types}")
+        if block_type in self.flipflop:
+            return getattr(self, block_type)[:self.flipflop[block_type].flip_amount]
+        return getattr(self, block_type)
+
+    def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]:
+        '''
+        Returns a list of (block_type, size).
+        If sort_by_size is True, the list is sorted by size.
+        '''
+        sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
+        if sort_by_size:
+            sizes.sort(key=lambda x: x[1])
+        return sizes
+
+    def get_block_module_size(self, block_type: str) -> int:
+        return comfy.model_management.module_size(getattr(self, block_type)[0])
 
+
+# Below is the implementation from contentis' prototype flip flop
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
@@ -379,28 +395,26 @@ def __call__old(self, **feed_dict):
 # patch_model_from_config(model, Wan.blocks_config)
 # return model
 
+# @register("QwenImageTransformer2DModel")
+# class QwenImage:
+#     @staticmethod
+#     def qwen_blocks_wrap(block, **kwargs):
+#         kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
+#                                                                          encoder_hidden_states=kwargs["encoder_hidden_states"],
+#                                                                          encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
+#                                                                          temb=kwargs["temb"],
+#                                                                          image_rotary_emb=kwargs["image_rotary_emb"],
+#                                                                          transformer_options=kwargs["transformer_options"])
+#         return kwargs
 
-@register("QwenImageTransformer2DModel")
-class QwenImage:
-    @staticmethod
-    def qwen_blocks_wrap(block, **kwargs):
-        kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
-                                                                         encoder_hidden_states=kwargs["encoder_hidden_states"],
-                                                                         encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
-                                                                         temb=kwargs["temb"],
-                                                                         image_rotary_emb=kwargs["image_rotary_emb"],
-                                                                         transformer_options=kwargs["transformer_options"])
-        return kwargs
-
-    blocks_config = FlipFlopConfig(block_name="transformer_blocks",
-                                   block_wrap_fn=qwen_blocks_wrap,
-                                   out_names=("encoder_hidden_states", "hidden_states"),
-                                   overwrite_forward="blocks_fwd",
-                                   pinned_staging=False)
+# blocks_config = FlipFlopConfig(block_name="transformer_blocks",
+#                                block_wrap_fn=qwen_blocks_wrap,
+#                                out_names=("encoder_hidden_states", "hidden_states"),
+#                                overwrite_forward="blocks_fwd",
+#                                pinned_staging=False)
 
-    @staticmethod
-    def patch(model):
-        patch_model_from_config(model, QwenImage.blocks_config)
-        return model
 
+# @staticmethod
+# def patch(model):
+#     patch_model_from_config(model, QwenImage.blocks_config)
+#     return model
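To make the new FlipFlopModule API concrete before the next file, here is a hypothetical model wired through it (a sketch mirroring the qwen_image usage below; MyDiT, its Linear stand-in blocks, and _block_fwd are invented here, not committed code):

import torch

class MyDiT(FlipFlopModule):
    def __init__(self, num_blocks: int = 8, dim: int = 64):
        super().__init__(block_types=("transformer_blocks",))
        self.transformer_blocks = torch.nn.ModuleList(
            torch.nn.Linear(dim, dim) for _ in range(num_blocks))

    def _block_fwd(self, i, block, x):
        return block(x)

    def forward(self, x):
        # Resident blocks run normally on the load device...
        for i, block in enumerate(self.get_blocks("transformer_blocks")):
            x = self._block_fwd(i, block, x)
        # ...then the offloaded tail streams through the flip/flop buffers.
        if "transformer_blocks" in self.flipflop:
            holder = self.flipflop["transformer_blocks"]
            with holder.context() as ctx:
                for i, block in enumerate(holder.blocks):
                    x = ctx(self._block_fwd, i, block, x)
        return x

# model = MyDiT().cuda()
# model.setup_flipflop_holders(block_percentage=0.5)  # stream half of the blocks
# y = model(torch.randn(1, 64, device="cuda"))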

comfy/ldm/qwen_image/model.py

Lines changed: 34 additions & 18 deletions
@@ -343,11 +343,36 @@ def __init__(
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
 
     def setup_flipflop_holders(self, block_percentage: float):
+        if "transformer_blocks" in self.flipflop:
+            return
+        import comfy.model_management
         # We hackily move any flipflopped blocks into holder so that our model management system does not see them.
         num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
-        self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        loading = []
+        for n, m in self.named_modules():
+            params = []
+            skip = False
+            for name, param in m.named_parameters(recurse=False):
+                params.append(name)
+            for name, param in m.named_parameters(recurse=True):
+                if name not in params:
+                    skip = True # skip random weights in non leaf modules
+                    break
+            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+                loading.append((comfy.model_management.module_size(m), n, m, params))
+        self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
         self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
 
+    def clean_flipflop_holders(self):
+        if "transformer_blocks" in self.flipflop:
+            self.flipflop["transformer_blocks"].clean_flipflop_blocks()
+            del self.flipflop["transformer_blocks"]
+
+    def get_transformer_blocks(self):
+        if "transformer_blocks" in self.flipflop:
+            return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount]
+        return self.transformer_blocks
+
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
@@ -409,17 +434,6 @@ def block_wrap(args):
 
         return encoder_hidden_states, hidden_states
 
-    def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
-        for i, block in enumerate(self.transformer_blocks):
-            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "blocks_fwd" in self.flipflop:
-            holder = self.flipflop["blocks_fwd"]
-            with holder.context() as ctx:
-                for i, block in enumerate(holder.transformer_blocks):
-                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-
-        return encoder_hidden_states, hidden_states
-
     def _forward(
         self,
         x,
@@ -487,12 +501,14 @@ def _forward(
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})
 
-        encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
-                                                               encoder_hidden_states=encoder_hidden_states,
-                                                               encoder_hidden_states_mask=encoder_hidden_states_mask,
-                                                               temb=temb, image_rotary_emb=image_rotary_emb,
-                                                               patches=patches, control=control, blocks_replace=blocks_replace, x=x,
-                                                               transformer_options=transformer_options)
+        for i, block in enumerate(self.get_transformer_blocks()):
+            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+        if "transformer_blocks" in self.flipflop:
+            holder = self.flipflop["transformer_blocks"]
+            with holder.context() as ctx:
+                for i, block in enumerate(holder.blocks):
+                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

comfy/model_patcher.py

Lines changed: 27 additions & 3 deletions
@@ -605,7 +605,27 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def supports_flipflop(self):
-        return hasattr(self.model.diffusion_model, "flipflop")
+        # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
+        if not hasattr(self.model, "diffusion_model"):
+            return False
+        if not hasattr(self.model.diffusion_model, "flipflop"):
+            return False
+        if not comfy.model_management.is_nvidia():
+            return False
+        if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
+            return False
+        return True
+
+    def init_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        # figure out how many b
+        self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"])
+
+    def clean_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        self.model.diffusion_model.clean_flipflop_holders()
 
     def _load_list(self):
         loading = []
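The call sites for the two new hooks are not part of this commit; a hypothetical lifecycle (assumed, for orientation only) would look like:

# Hypothetical usage; this commit does not include the call sites.
patcher.model_options["flipflop_block_percentage"] = 0.7  # fraction of blocks handed to the holders, per setup_flipflop_holders
patcher.init_flipflop()    # split blocks into resident + streamed flip/flop holders
# ... run sampling ...
patcher.clean_flipflop()   # release the flip/flop GPU buffers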
@@ -628,6 +648,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
+        if self.supports_flipflop():
+            ...
         loading = self._load_list()
 
         load_completely = []
@@ -647,6 +670,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
+                    lowvram_mem_counter += module_mem
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue
 
@@ -709,10 +733,10 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
             x[2].to(device_to)
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
             self.model.model_lowvram = False
         if full_load:
             self.model.to(device_to)
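With the labels, the two messages now document themselves. Using made-up sizes, the new format renders roughly as:

loaded partially; 4096.00 MB usable memory, 3584.12 MB loaded, 2210.50 MB offloaded, lowvram patches: 0
loaded completely; 8192.00 MB usable memory, 5120.25 MB loaded, full load: True

versus the old unlabeled "loaded partially 4096.0 3584.125 0", whose fields had to be memorized.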

comfy_extras/nodes_flipflop.py

Lines changed: 0 additions & 28 deletions
@@ -3,33 +3,6 @@
 
 from comfy_api.latest import ComfyExtension, io
 
-from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
-
-class FlipFlopOld(io.ComfyNode):
-    @classmethod
-    def define_schema(cls) -> io.Schema:
-        return io.Schema(
-            node_id="FlipFlop",
-            display_name="FlipFlop (Old)",
-            category="_for_testing",
-            inputs=[
-                io.Model.Input(id="model")
-            ],
-            outputs=[
-                io.Model.Output()
-            ],
-            description="Apply FlipFlop transformation to model using registry-based patching"
-        )
-
-    @classmethod
-    def execute(cls, model) -> io.NodeOutput:
-        patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
-        if patch_cls is None:
-            raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
-
-        model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
-
-        return io.NodeOutput(model)
 
 class FlipFlop(io.ComfyNode):
     @classmethod
@@ -62,7 +35,6 @@ class FlipFlopExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
         return [
-            FlipFlopOld,
             FlipFlop,
         ]
 