1414
1515import  inspect 
1616from  dataclasses  import  dataclass , field 
17- from  typing  import  Any , Dict , List , Tuple , Union 
17+ from  typing  import  Any , Dict , List , Tuple , Union , Type 
18+ from  collections  import  OrderedDict 
1819
1920import  torch 
2021from  tqdm .auto  import  tqdm 
3031from  .pipeline_loading_utils  import  _fetch_class_library_tuple , _get_pipeline_class 
3132from  .pipeline_utils  import  DiffusionPipeline 
3233
34+ import  warnings 
35+ 
3336
3437if  is_accelerate_available ():
3538    import  accelerate 
@@ -99,6 +102,7 @@ class PipelineBlock:
99102    optional_components  =  []
100103    required_components  =  []
101104    required_auxiliaries  =  []
105+     optional_auxiliaries  =  []
102106
103107    @property  
104108    def  inputs (self ) ->  Tuple [Tuple [str , Any ], ...]:
@@ -122,7 +126,7 @@ def __init__(self, **kwargs):
122126        for  key , value  in  kwargs .items ():
123127            if  key  in  self .required_components  or  key  in  self .optional_components :
124128                self .components [key ] =  value 
125-             elif  key  in  self .required_auxiliaries :
129+             elif  key  in  self .required_auxiliaries   or   key   in   self . optional_auxiliaries :
126130                self .auxiliaries [key ] =  value 
127131            else :
128132                self .configs [key ] =  value 
@@ -152,10 +156,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
152156                components_to_add [component_name ] =  component 
153157
154158        # add auxiliaries 
159+         expected_auxiliaries  =  set (cls .required_auxiliaries  +  cls .optional_auxiliaries )
155160        # - auxiliaries that are passed in kwargs 
156-         auxiliaries_to_add  =  {k : kwargs .pop (k ) for  k  in  cls . required_auxiliaries  if  k  in  kwargs }
161+         auxiliaries_to_add  =  {k : kwargs .pop (k ) for  k  in  expected_auxiliaries  if  k  in  kwargs }
157162        # - auxiliaries that are in the pipeline 
158-         for  aux_name  in  cls . required_auxiliaries :
163+         for  aux_name  in  expected_auxiliaries :
159164            if  hasattr (pipe , aux_name ) and  aux_name  not  in auxiliaries_to_add :
160165                auxiliaries_to_add [aux_name ] =  getattr (pipe , aux_name )
161166        block_kwargs  =  {** components_to_add , ** auxiliaries_to_add }
@@ -167,7 +172,7 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
167172        expected_configs  =  {
168173            k 
169174            for  k  in  pipe .config .keys ()
170-             if  k  in  init_params  and  k  not  in expected_components  and  k  not  in cls . required_auxiliaries 
175+             if  k  in  init_params  and  k  not  in expected_components  and  k  not  in expected_auxiliaries 
171176        }
172177
173178        for  config_name  in  expected_configs :
@@ -210,6 +215,188 @@ def __repr__(self):
210215        )
211216
212217
218+ def  combine_inputs (* input_lists : List [Tuple [str , Any ]]) ->  List [Tuple [str , Any ]]:
219+     """ 
220+     Combines multiple lists of (name, default_value) tuples. 
221+     For duplicate inputs, updates only if current value is None and new value is not None. 
222+     Warns if multiple non-None default values exist for the same input. 
223+     """ 
224+     combined_dict  =  {}
225+     for  inputs  in  input_lists :
226+         for  name , value  in  inputs :
227+             if  name  in  combined_dict :
228+                 current_value  =  combined_dict [name ]
229+                 if  current_value  is  not None  and  value  is  not None  and  current_value  !=  value :
230+                     warnings .warn (
231+                         f"Multiple different default values found for input '{ name }  
232+                         f"{ current_value } { value } { current_value }  
233+                     )
234+                 if  current_value  is  None  and  value  is  not None :
235+                     combined_dict [name ] =  value 
236+             else :
237+                 combined_dict [name ] =  value 
238+     return  list (combined_dict .items ())
239+ 
240+ 
241+ 
242+ class  AutoStep (PipelineBlock ):
243+     base_blocks  =  []     # list of block classes 
244+     trigger_inputs  =  []  # list of trigger inputs (None for default block) 
245+     required_components  =  []
246+     optional_components  =  []
247+     required_auxiliaries  =  []
248+     optional_auxiliaries  =  []
249+     
250+     def  __init__ (self , ** kwargs ):
251+         self .blocks  =  []
252+         
253+         for  block_cls , trigger  in  zip (self .base_blocks , self .trigger_inputs ):
254+             # Check components 
255+             missing_components  =  [
256+                 component  for  component  in  block_cls .required_components  
257+                 if  component  not  in kwargs 
258+             ]
259+             
260+             # Check auxiliaries 
261+             missing_auxiliaries  =  [
262+                 auxiliary  for  auxiliary  in  block_cls .required_auxiliaries  
263+                 if  auxiliary  not  in kwargs 
264+             ]
265+             
266+             if  not  missing_components  and  not  missing_auxiliaries :
267+                 # Only get kwargs that the block's __init__ accepts 
268+                 block_params  =  inspect .signature (block_cls .__init__ ).parameters 
269+                 block_kwargs  =  {
270+                     k : v  for  k , v  in  kwargs .items () 
271+                     if  k  in  block_params 
272+                 }
273+                 self .blocks .append (block_cls (** block_kwargs ))
274+                 
275+                 # Print message about trigger condition 
276+                 if  trigger  is  None :
277+                     print (f"Added default block: { block_cls .__name__ }  )
278+                 else :
279+                     print (f"Added block { block_cls .__name__ } { trigger }  )
280+             else :
281+                 if  trigger  is  None :
282+                     print (f"Cannot add default block { block_cls .__name__ }  )
283+                 else :
284+                     print (f"Cannot add block { block_cls .__name__ } { trigger }  )
285+                 if  missing_components :
286+                     print (f"  - Missing components: { missing_components }  )
287+                 if  missing_auxiliaries :
288+                     print (f"  - Missing auxiliaries: { missing_auxiliaries }  )
289+     
290+     @property  
291+     def  components (self ):
292+         # Combine components from all blocks 
293+         components  =  {}
294+         for  block  in  self .blocks :
295+             components .update (block .components )
296+         return  components 
297+     
298+     @property  
299+     def  auxiliaries (self ):
300+         # Combine auxiliaries from all blocks 
301+         auxiliaries  =  {}
302+         for  block  in  self .blocks :
303+             auxiliaries .update (block .auxiliaries )
304+         return  auxiliaries 
305+     
306+     @property  
307+     def  configs (self ):
308+         # Combine configs from all blocks 
309+         configs  =  {}
310+         for  block  in  self .blocks :
311+             configs .update (block .configs )
312+         return  configs 
313+     
314+     @property  
315+     def  inputs (self ) ->  List [Tuple [str , Any ]]:
316+         return  combine_inputs (* (block .inputs  for  block  in  self .blocks ))
317+     
318+     @property  
319+     def  intermediates_inputs (self ) ->  List [str ]:
320+         return  list (set ().union (* (
321+             block .intermediates_inputs  for  block  in  self .blocks 
322+         )))
323+     
324+     @property  
325+     def  intermediates_outputs (self ) ->  List [str ]:
326+         return  list (set ().union (* (
327+             block .intermediates_outputs  for  block  in  self .blocks 
328+         )))
329+     
330+     def  __call__ (self , pipeline , state ):
331+         # Check triggers in priority order 
332+         for  idx , trigger  in  enumerate (self .trigger_inputs [:- 1 ]):  # Skip last (None) trigger 
333+             if  state .get_input (trigger ) is  not None :
334+                 return  self .blocks [idx ](pipeline , state )
335+         # If no triggers match, use the default block (last one) 
336+         return  self .blocks [- 1 ](pipeline , state )
337+ 
338+ 
339+ def  make_auto_step (pipeline_block_map : OrderedDict ) ->  Type [AutoStep ]:
340+     """ 
341+     Creates a new AutoStep subclass with updated class attributes based on the pipeline block map. 
342+      
343+     Args: 
344+         pipeline_block_map: OrderedDict mapping trigger inputs to pipeline block classes. 
345+                           Order determines priority (earlier entries take precedence). 
346+                           Must include None key for the default block. 
347+     """ 
348+     blocks  =  list (pipeline_block_map .values ())
349+     triggers  =  list (pipeline_block_map .keys ())
350+     
351+     # Get all expected components (either required or optional by any block) 
352+     expected_components  =  []
353+     for  block  in  blocks :
354+         for  component  in  (block .required_components  +  block .optional_components ):
355+             if  component  not  in expected_components :
356+                 expected_components .append (component )
357+     
358+     # A component is required if it's in required_components of all blocks 
359+     required_components  =  [
360+         component  for  component  in  expected_components 
361+         if  all (component  in  block .required_components  for  block  in  blocks )
362+     ]
363+     
364+     # All other expected components are optional 
365+     optional_components  =  [
366+         component  for  component  in  expected_components 
367+         if  component  not  in required_components 
368+     ]
369+ 
370+     # Get all expected auxiliaries (either required or optional by any block) 
371+     expected_auxiliaries  =  []
372+     for  block  in  blocks :
373+         for  auxiliary  in  (block .required_auxiliaries  +  getattr (block , 'optional_auxiliaries' , [])):
374+             if  auxiliary  not  in expected_auxiliaries :
375+                 expected_auxiliaries .append (auxiliary )
376+     
377+     # An auxiliary is required if it's in required_auxiliaries of all blocks 
378+     required_auxiliaries  =  [
379+         auxiliary  for  auxiliary  in  expected_auxiliaries 
380+         if  all (auxiliary  in  block .required_auxiliaries  for  block  in  blocks )
381+     ]
382+     
383+     # All other expected auxiliaries are optional 
384+     optional_auxiliaries  =  [
385+         auxiliary  for  auxiliary  in  expected_auxiliaries 
386+         if  auxiliary  not  in required_auxiliaries 
387+     ]
388+ 
389+     # Create new class with updated attributes 
390+     return  type ('AutoStep' , (AutoStep ,), {
391+         'base_blocks' : blocks ,
392+         'trigger_inputs' : triggers ,
393+         'required_components' : required_components ,
394+         'optional_components' : optional_components ,
395+         'required_auxiliaries' : required_auxiliaries ,
396+         'optional_auxiliaries' : optional_auxiliaries ,
397+     })
398+ 
399+ 
213400class  ModularPipelineBuilder (ConfigMixin ):
214401    """ 
215402    Base class for all Modular pipelines. 
@@ -585,7 +772,7 @@ def from_pipe(cls, pipeline, **kwargs):
585772        # Create each block, passing only unused items that the block expects 
586773        for  block_class  in  modular_pipeline_class .default_pipeline_blocks :
587774            expected_components  =  set (block_class .required_components  +  block_class .optional_components )
588-             expected_auxiliaries  =  set (block_class .required_auxiliaries )
775+             expected_auxiliaries  =  set (block_class .required_auxiliaries   +   block_class . optional_auxiliaries )
589776
590777            # Get init parameters to check for expected configs 
591778            init_params  =  inspect .signature (block_class .__init__ ).parameters 
0 commit comments