import torch.cuda as cuda
import copy
from typing import List, Tuple
-from dataclasses import dataclass

import comfy.model_management

-FLIPFLOP_REGISTRY = {}
-
-def register(name):
-    def decorator(cls):
-        FLIPFLOP_REGISTRY[name] = cls
-        return cls
-    return decorator
-
-
-@dataclass
-class FlipFlopConfig:
-    block_name: str
-    block_wrap_fn: callable
-    out_names: Tuple[str]
-    overwrite_forward: str
-    pinned_staging: bool = False
-    inference_device: str = "cuda"
-    offloading_device: str = "cpu"
-
-
-def patch_model_from_config(model, config: FlipFlopConfig):
-    block_list = getattr(model, config.block_name)
-    flip_flop_transformer = FlipFlopTransformer(block_list,
-                                                block_wrap_fn=config.block_wrap_fn,
-                                                out_names=config.out_names,
-                                                offloading_device=config.offloading_device,
-                                                inference_device=config.inference_device,
-                                                pinned_staging=config.pinned_staging)
-    delattr(model, config.block_name)
-    setattr(model, config.block_name, flip_flop_transformer)
-    setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
-

class FlipFlopContext:
    def __init__(self, holder: FlipFlopHolder):
        self.holder = holder
        self.reset()

    def reset(self):
-        self.num_blocks = len(self.holder.transformer_blocks)
+        self.num_blocks = len(self.holder.blocks)
        self.first_flip = True
        self.first_flop = True
        self.last_flip = False
        self.last_flop = False
+        # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly

    def __enter__(self):
        self.reset()
@@ -71,9 +39,9 @@ def do_flip(self, func, i: int, _, *args, **kwargs):
            next_flop_i = next_flop_i - self.num_blocks
            self.last_flip = True
        if not self.first_flip:
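+            # prefetch: while this block runs on the flip buffer, stream the weights of the
+            # upcoming offloaded block into the flop buffer so the copy overlaps with compute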
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
        if self.last_flip:
-            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
+            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
        self.first_flip = False
        return out

@@ -89,9 +57,9 @@ def do_flop(self, func, i: int, _, *args, **kwargs):
        if next_flip_i >= self.num_blocks:
            next_flip_i = next_flip_i - self.num_blocks
            self.last_flop = True
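+        # symmetric prefetch: while the flop buffer computes, refill the flip buffer with
+        # the weights of the next offloaded block it will have to execute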
-        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
+        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
        if self.last_flop:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
        self.first_flop = False
        return out

@@ -106,19 +74,20 @@ def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):


class FlipFlopHolder:
-    def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.load_device = torch.device(inference_device)
-        self.offload_device = torch.device(offloading_device)
-        self.transformer_blocks = transformer_blocks
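+    # Double-buffered weight streaming: `flip` and `flop` are two GPU-resident copies of a
+    # block. While one of them executes block i, the other one's weights are overwritten
+    # with those of an upcoming offloaded block, so transfers overlap with compute.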
+    def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"):
+        self.load_device = torch.device(load_device)
+        self.offload_device = torch.device(offload_device)
+        self.blocks = blocks
+        self.flip_amount = flip_amount

        self.block_module_size = 0
-        if len(self.transformer_blocks) > 0:
-            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
+        if len(self.blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.blocks[0])

        self.flip: torch.nn.Module = None
        self.flop: torch.nn.Module = None
        # TODO: make initialization happen in model management code/model patcher, not here
-        self.initialize_flipflop_blocks(self.load_device)
+        self.init_flipflop_blocks(self.load_device)

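+        # compute stays on the default stream; weight copies run on a dedicated stream so
+        # they can be overlapped with it (ordered against compute via the flip/flop events)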
12392 self .compute_stream = cuda .default_stream (self .load_device )
12493 self .cpy_stream = cuda .Stream (self .load_device )
@@ -142,10 +111,57 @@ def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy
    def context(self):
        return FlipFlopContext(self)

-    def initialize_flipflop_blocks(self, load_device: torch.device):
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+    def init_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
+
+    def clean_flipflop_blocks(self):
+        del self.flip
+        del self.flop
+        self.flip = None
+        self.flop = None
+
+
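+# Base class for transformer models: declares which block lists ("block_types") can be
+# flip/flopped and owns one FlipFlopHolder per offloaded block list.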
+class FlopFlopModule(torch.nn.Module):
+    def __init__(self, block_types: tuple[str, ...]):
+        super().__init__()
+        self.block_types = block_types
+        self.flipflop: dict[str, FlipFlopHolder] = {}
+
+    def setup_flipflop_holders(self, block_percentage: float):
+        for block_type in self.block_types:
+            if block_type in self.flipflop:
+                continue
+            blocks = getattr(self, block_type)
+            # the first num_blocks blocks stay resident on the GPU; the rest are flip/flopped
+            num_blocks = int(len(blocks) * (1.0 - block_percentage))
+            self.flipflop[block_type] = FlipFlopHolder(blocks[num_blocks:], num_blocks)
+
+    def clean_flipflop_holders(self):
+        # iterate over a copy of the keys since entries are deleted while iterating
+        for block_type in list(self.flipflop.keys()):
+            self.flipflop[block_type].clean_flipflop_blocks()
+            del self.flipflop[block_type]
+
+    def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
+        if block_type not in self.block_types:
+            raise ValueError(f"Block type {block_type} not found in {self.block_types}")
+        if block_type in self.flipflop:
+            return getattr(self, block_type)[:self.flipflop[block_type].flip_amount]
+        return getattr(self, block_type)
+
+    def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]:
+        '''
+        Returns a list of (block_type, size).
+        If sort_by_size is True, the list is sorted by size.
+        '''
+        sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
+        if sort_by_size:
+            sizes.sort(key=lambda x: x[1])
+        return sizes
+
+    def get_block_module_size(self, block_type: str) -> int:
+        return comfy.model_management.module_size(getattr(self, block_type)[0])
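+
+# Hypothetical usage sketch (illustrative, not part of this change): a model could inherit
+# from FlopFlopModule, keep the leading blocks resident on the GPU, and run the offloaded
+# tail through a FlipFlopHolder. `MyModel`, `make_block` and `run_block` are made-up names;
+# the exact func signature is whatever FlipFlopContext.__call__/do_flip expect.
+#
+# class MyModel(FlopFlopModule):
+#     def __init__(self):
+#         super().__init__(block_types=("transformer_blocks",))
+#         self.transformer_blocks = torch.nn.ModuleList([make_block() for _ in range(40)])
+#
+#     def forward_blocks(self, x, **kwargs):
+#         for block in self.get_blocks("transformer_blocks"):  # GPU-resident prefix
+#             x = block(x, **kwargs)
+#         holder = self.flipflop.get("transformer_blocks")
+#         if holder is not None:  # offloaded tail, streamed through the flip/flop buffers
+#             with holder.context() as ctx:
+#                 for i, block in enumerate(holder.blocks):
+#                     # ctx substitutes the GPU-resident flip/flop copy for `block`
+#                     x = ctx(run_block, i, block, x, **kwargs)
+#         return x
+#
+# model = MyModel()
+# model.setup_flipflop_holders(block_percentage=0.5)  # flip/flop the last half of the blocks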

+
+# Below is the implementation from contentis' prototype flip flop
class FlipFlopTransformer:
    def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
        self.transformer_blocks = transformer_blocks
@@ -379,28 +395,26 @@ def __call__old(self, **feed_dict):
#         patch_model_from_config(model, Wan.blocks_config)
#         return model

+# @register("QwenImageTransformer2DModel")
+# class QwenImage:
+#     @staticmethod
+#     def qwen_blocks_wrap(block, **kwargs):
+#         kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
+#                                                                          encoder_hidden_states=kwargs["encoder_hidden_states"],
+#                                                                          encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
+#                                                                          temb=kwargs["temb"],
+#                                                                          image_rotary_emb=kwargs["image_rotary_emb"],
+#                                                                          transformer_options=kwargs["transformer_options"])
+#         return kwargs

-@register("QwenImageTransformer2DModel")
-class QwenImage:
-    @staticmethod
-    def qwen_blocks_wrap(block, **kwargs):
-        kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
-                                                                         encoder_hidden_states=kwargs["encoder_hidden_states"],
-                                                                         encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
-                                                                         temb=kwargs["temb"],
-                                                                         image_rotary_emb=kwargs["image_rotary_emb"],
-                                                                         transformer_options=kwargs["transformer_options"])
-        return kwargs
-
-    blocks_config = FlipFlopConfig(block_name="transformer_blocks",
-                                   block_wrap_fn=qwen_blocks_wrap,
-                                   out_names=("encoder_hidden_states", "hidden_states"),
-                                   overwrite_forward="blocks_fwd",
-                                   pinned_staging=False)
-
+#     blocks_config = FlipFlopConfig(block_name="transformer_blocks",
+#                                    block_wrap_fn=qwen_blocks_wrap,
+#                                    out_names=("encoder_hidden_states", "hidden_states"),
+#                                    overwrite_forward="blocks_fwd",
+#                                    pinned_staging=False)

-    @staticmethod
-    def patch(model):
-        patch_model_from_config(model, QwenImage.blocks_config)
-        return model

+#     @staticmethod
+#     def patch(model):
+#         patch_model_from_config(model, QwenImage.blocks_config)
+#         return model