@@ -197,14 +197,38 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]])
197197    return  list (combined_dict .items ())
198198
199199
200- class  MultiPipelineBlocks :
200+ class  AutoPipelineBlocks :
201201    """ 
202-     A class that combines multiple pipeline block classes into one. When used, it has same API and properties as 
203-     PipelineBlock. And it can be used in ModularPipeline as a single pipeline block. 
202+     A class that automatically selects a block to run based on the inputs. 
203+ 
204+     Attributes: 
205+         block_classes: List of block classes to be used 
206+         block_names: List of prefixes for each block 
207+         block_trigger_inputs: List of input names that trigger specific blocks, with None for default 
204208    """ 
205209
206210    block_classes  =  []
207-     block_prefixes  =  []
211+     block_names  =  []
212+     block_trigger_inputs  =  []
213+ 
214+     def  __init__ (self ):
215+         blocks  =  OrderedDict ()
216+         for  block_name , block_cls  in  zip (self .block_names , self .block_classes ):
217+             blocks [block_name ] =  block_cls ()
218+         self .blocks  =  blocks 
219+         if  not  (len (self .block_classes ) ==  len (self .block_names ) ==  len (self .block_trigger_inputs )):
220+             raise  ValueError (f"In { self .__class__ .__name__ }  )
221+         default_blocks  =  [t  for  t  in  self .block_trigger_inputs  if  t  is  None ]
222+         if  len (default_blocks ) >  1  or  (
223+                 len (default_blocks ) ==  1  and  self .block_trigger_inputs [- 1 ] is  not None 
224+             ):
225+             raise  ValueError (
226+                 f"In { self .__class__ .__name__ }  
227+                 "in block_trigger_inputs." 
228+             )
229+ 
230+         # Map trigger inputs to block objects 
231+         self .trigger_to_block_map  =  dict (zip (self .block_trigger_inputs , self .blocks .values ()))
208232
209233    @property  
210234    def  model_name (self ):
@@ -228,13 +252,6 @@ def expected_configs(self):
228252                    expected_configs .append (config )
229253        return  expected_configs 
230254
231-     def  __init__ (self ):
232-         blocks  =  OrderedDict ()
233-         for  block_prefix , block_cls  in  zip (self .block_prefixes , self .block_classes ):
234-             block_name  =  f"{ block_prefix }   if  block_prefix  !=  ""  else  "step" 
235-             blocks [block_name ] =  block_cls ()
236-         self .blocks  =  blocks 
237- 
238255    # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc 
239256    @property  
240257    def  components (self ):
@@ -265,60 +282,6 @@ def configs(self):
265282            configs .update (block .configs )
266283        return  configs 
267284
268-     @property  
269-     def  inputs (self ) ->  List [Tuple [str , Any ]]:
270-         raise  NotImplementedError ("inputs property must be implemented in subclasses" )
271- 
272-     @property  
273-     def  intermediates_inputs (self ) ->  List [str ]:
274-         raise  NotImplementedError ("intermediates_inputs property must be implemented in subclasses" )
275- 
276-     @property  
277-     def  intermediates_outputs (self ) ->  List [str ]:
278-         raise  NotImplementedError ("intermediates_outputs property must be implemented in subclasses" )
279- 
280-     def  __call__ (self , pipeline , state ):
281-         raise  NotImplementedError ("__call__ method must be implemented in subclasses" )
282- 
283- 
284- 
285- 
286- # YiYi TODO: remove the trigger input logic and keep it more flexible and less convenient: 
287- # user will need to explicitly write the dispatch logic in __call__ for each subclass of this 
288- class  AutoPipelineBlocks (MultiPipelineBlocks ):
289-     """ 
290-     A class that automatically selects which block to run based on trigger inputs. 
291- 
292-     Attributes: 
293-         block_classes: List of block classes to be used 
294-         block_prefixes: List of prefixes for each block 
295-         block_trigger_inputs: List of input names that trigger specific blocks, with None for default 
296-     """ 
297- 
298-     block_classes  =  []
299-     block_prefixes  =  []
300-     block_trigger_inputs  =  []
301- 
302-     def  __init__ (self ):
303-         super ().__init__ ()
304-         self .__post_init__ ()
305- 
306-     def  __post_init__ (self ):
307-         """ 
308-         Create mapping of trigger inputs directly to block objects. Validates that there is at most one default block 
309-         (None trigger). 
310-         """ 
311-         # Check for at most one default block 
312-         default_blocks  =  [t  for  t  in  self .block_trigger_inputs  if  t  is  None ]
313-         if  len (default_blocks ) >  1 :
314-             raise  ValueError (
315-                 f"Multiple default blocks specified in { self .__class__ .__name__ }  
316-                 "Must include at most one None in block_trigger_inputs." 
317-             )
318- 
319-         # Map trigger inputs to block objects 
320-         self .trigger_to_block_map  =  dict (zip (self .block_trigger_inputs , self .blocks .values ()))
321- 
322285    @property  
323286    def  inputs (self ) ->  List [Tuple [str , Any ]]:
324287        named_inputs  =  [(name , block .inputs ) for  name , block  in  self .blocks .items ()]
@@ -335,30 +298,15 @@ def intermediates_outputs(self) -> List[str]:
335298    @torch .no_grad () 
336299    def  __call__ (self , pipeline , state : PipelineState ) ->  PipelineState :
337300        # Find default block first (if any) 
338-         default_block  =  self .trigger_to_block_map .get (None )
339- 
340-         # Check which trigger inputs are present 
341-         active_triggers  =  [
342-             input_name 
343-             for  input_name  in  self .block_trigger_inputs 
344-             if  input_name  is  not None  and  state .get_input (input_name ) is  not None 
345-         ]
346- 
347-         # If multiple triggers are active, raise error 
348-         if  len (active_triggers ) >  1 :
349-             trigger_names  =  [f"'{ t }   for  t  in  active_triggers ]
350-             raise  ValueError (
351-                 f"Multiple trigger inputs found ({ ', ' .join (trigger_names )}  
352-                 f"Only one trigger input can be provided for { self .__class__ .__name__ }  
353-             )
354301
355-         # Get the  block to run (use default if no triggers active )
356-         block   =   self .trigger_to_block_map . get ( active_triggers [ 0 ])  if   active_triggers   else   default_block 
357-         if  block  is  None :
358-             logger . warning ( f"No valid block found in  { self .__class__ . __name__ } , skipping." ) 
359-             return   pipeline ,  state 
302+         block  =   self . trigger_to_block_map . get ( None )
303+         for   input_name   in   self .block_trigger_inputs : 
304+              if  input_name  is   not   None   and   state . get_input ( input_name )  is   not None :
305+                  block   =   self .trigger_to_block_map [ input_name ] 
306+                  break 
360307
361308        try :
309+             logger .info (f"Running block: { block .__class__ .__name__ } { input_name }  )
362310            return  block (pipeline , state )
363311        except  Exception  as  e :
364312            error_msg  =  (
@@ -440,10 +388,70 @@ def __repr__(self):
440388        )
441389
442390
443- class  SequentialPipelineBlocks ( MultiPipelineBlocks ) :
391+ class  SequentialPipelineBlocks :
444392    """ 
445393    A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. 
446394    """ 
395+     block_classes  =  []
396+     block_names  =  []
397+ 
398+     @property  
399+     def  model_name (self ):
400+         return  next (iter (self .blocks .values ())).model_name 
401+ 
402+     @property  
403+     def  expected_components (self ):
404+         expected_components  =  []
405+         for  block  in  self .blocks .values ():
406+             for  component  in  block .expected_components :
407+                 if  component  not  in expected_components :
408+                     expected_components .append (component )
409+         return  expected_components 
410+ 
411+     @property  
412+     def  expected_configs (self ):
413+         expected_configs  =  []
414+         for  block  in  self .blocks .values ():
415+             for  config  in  block .expected_configs :
416+                 if  config  not  in expected_configs :
417+                     expected_configs .append (config )
418+         return  expected_configs 
419+ 
420+     def  __init__ (self ):
421+         blocks  =  OrderedDict ()
422+         for  block_name , block_cls  in  zip (self .block_names , self .block_classes ):
423+             blocks [block_name ] =  block_cls ()
424+         self .blocks  =  blocks 
425+ 
426+     # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc 
427+     @property  
428+     def  components (self ):
429+         # Combine components from all blocks 
430+         components  =  {}
431+         for  block_name , block  in  self .blocks .items ():
432+             for  key , value  in  block .components .items ():
433+                 # Only update if: 
434+                 # 1. Key doesn't exist yet in components, OR 
435+                 # 2. New value is not None 
436+                 if  key  not  in components  or  value  is  not None :
437+                     components [key ] =  value 
438+         return  components 
439+ 
440+     @property  
441+     def  auxiliaries (self ):
442+         # Combine auxiliaries from all blocks 
443+         auxiliaries  =  {}
444+         for  block_name , block  in  self .blocks .items ():
445+             auxiliaries .update (block .auxiliaries )
446+         return  auxiliaries 
447+ 
448+     @property  
449+     def  configs (self ):
450+         # Combine configs from all blocks 
451+         configs  =  {}
452+         for  block_name , block  in  self .blocks .items ():
453+             configs .update (block .configs )
454+         return  configs 
447455
448456    @property  
449457    def  inputs (self ) ->  List [Tuple [str , Any ]]:
@@ -467,7 +475,11 @@ def intermediates_inputs(self) -> List[str]:
467475    @property  
468476    def  intermediates_outputs (self ) ->  List [str ]:
469477        return  list (set ().union (* (block .intermediates_outputs  for  block  in  self .blocks .values ())))
470- 
478+     
479+     @property  
480+     def  final_intermediates_outputs (self ) ->  List [str ]:
481+         return  next (reversed (self .blocks .values ())).intermediates_outputs 
482+     
471483    @torch .no_grad () 
472484    def  __call__ (self , pipeline , state : PipelineState ) ->  PipelineState :
473485        for  block_name , block  in  self .blocks .items ():
@@ -536,7 +548,8 @@ def __repr__(self):
536548        intermediates_str  =  (
537549            "\n     Intermediates:\n " 
538550            f"      - inputs: { ', ' .join (self .intermediates_inputs )} \n " 
539-             f"      - outputs: { ', ' .join (self .intermediates_outputs )}  
551+             f"      - outputs: { ', ' .join (self .intermediates_outputs )} \n " 
552+             f"      - final outputs: { ', ' .join (self .final_intermediates_outputs )}  
540553        )
541554
542555        return  (
@@ -772,7 +785,7 @@ def __repr__(self):
772785        output  +=  "Pipeline Block:\n " 
773786        output  +=  "--------------\n " 
774787        block  =  self .pipeline_block 
775-         if  isinstance (block , MultiPipelineBlocks ):
788+         if  hasattr (block , "blocks" ):
776789            output  +=  f"{ block .__class__ .__name__ } \n " 
777790            # Add sub-blocks information 
778791            for  sub_block_name , sub_block  in  block .blocks .items ():
@@ -787,13 +800,10 @@ def __repr__(self):
787800            output  +=  "\n " 
788801
789802        # Add final intermediate outputs for SequentialPipelineBlocks 
790-         if  isinstance (block , SequentialPipelineBlocks ):
791-             last_block  =  list (block .blocks .values ())[- 1 ]
792-             if  hasattr (last_block , "intermediates_outputs" ):
793-                 final_outputs  =  last_block .intermediates_outputs 
794-                 final_intermediates_str  =  f"   (final intermediate outputs: { ', ' .join (final_outputs )}  
795-                 output  +=  f"   { final_intermediates_str } \n " 
796-                 output  +=  "\n " 
803+         if  hasattr (block , "final_intermediate_output" ):
804+             final_intermediates_str  =  f"   (final intermediate outputs: { ', ' .join (block .final_intermediate_output )}  
805+             output  +=  f"   { final_intermediates_str } \n " 
806+             output  +=  "\n " 
797807
798808        # List the components registered in the pipeline 
799809        output  +=  "Registered Components:\n " 
0 commit comments