@@ -282,7 +282,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None,
282282 state = PipelineState ()
283283
284284 if not hasattr (self , "loader" ):
285- logger .warning ("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline." )
285+ logger .info ("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline." )
286286 self .loader = None
287287
288288 # Make a copy of the input kwargs
@@ -313,7 +313,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None,
313313
314314 # Warn about unexpected inputs
315315 if len (passed_kwargs ) > 0 :
316- logger . warning (f"Unexpected input '{ passed_kwargs .keys ()} ' provided. This input will be ignored." )
316+ warnings . warn (f"Unexpected input '{ passed_kwargs .keys ()} ' provided. This input will be ignored." )
317317 # Run the pipeline
318318 with torch .no_grad ():
319319 try :
@@ -373,7 +373,6 @@ def expected_configs(self) -> List[ConfigSpec]:
373373 return []
374374
375375
376- # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
377376 @property
378377 def inputs (self ) -> List [InputParam ]:
379378 """List of input parameters. Must be implemented by subclasses."""
@@ -389,27 +388,40 @@ def intermediates_outputs(self) -> List[OutputParam]:
389388 """List of intermediate output parameters. Must be implemented by subclasses."""
390389 return []
391390
391+ def _get_outputs (self ):
392+ return self .intermediates_outputs
393+
394+ # YiYi TODO: is it too easy for user to unintentionally override these properties?
392395 # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
393396 @property
394397 def outputs (self ) -> List [OutputParam ]:
395- return self .intermediates_outputs
398+ return self ._get_outputs ()
396399
397- @property
398- def required_inputs (self ) -> List [str ]:
400+ def _get_required_inputs (self ):
399401 input_names = []
400402 for input_param in self .inputs :
401403 if input_param .required :
402404 input_names .append (input_param .name )
403405 return input_names
404406
405407 @property
406- def required_intermediates_inputs (self ) -> List [str ]:
408+ def required_inputs (self ) -> List [str ]:
409+ return self ._get_required_inputs ()
410+
411+
412+ def _get_required_intermediates_inputs (self ):
407413 input_names = []
408414 for input_param in self .intermediates_inputs :
409415 if input_param .required :
410416 input_names .append (input_param .name )
411417 return input_names
412418
419+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
420+ # intermediate_inputs is by default required, unless you manually handle it inside the block
421+ @property
422+ def required_intermediates_inputs (self ) -> List [str ]:
423+ return self ._get_required_intermediates_inputs ()
424+
413425
414426 def __call__ (self , pipeline , state : PipelineState ) -> PipelineState :
415427 raise NotImplementedError ("__call__ method must be implemented in subclasses" )
@@ -521,6 +533,30 @@ def add_block_state(self, state: PipelineState, block_state: BlockState):
521533 raise ValueError (f"Intermediate output '{ output_param .name } ' is missing in block state" )
522534 param = getattr (block_state , output_param .name )
523535 state .add_intermediate (output_param .name , param , output_param .kwargs_type )
536+
537+ for input_param in self .intermediates_inputs :
538+ if hasattr (block_state , input_param .name ):
539+ param = getattr (block_state , input_param .name )
540+ # Only add if the value is different from what's in the state
541+ current_value = state .get_intermediate (input_param .name )
542+ if current_value is not param : # Using identity comparison to check if object was modified
543+ state .add_intermediate (input_param .name , param , input_param .kwargs_type )
544+
545+ for input_param in self .intermediates_inputs :
546+ if input_param .name and hasattr (block_state , input_param .name ):
547+ param = getattr (block_state , input_param .name )
548+ # Only add if the value is different from what's in the state
549+ current_value = state .get_intermediate (input_param .name )
550+ if current_value is not param : # Using identity comparison to check if object was modified
551+ state .add_intermediate (input_param .name , param , input_param .kwargs_type )
552+ elif input_param .kwargs_type :
553+ # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
554+ # we need to first find out which inputs are and loop through them.
555+ intermediates_kwargs = state .get_intermediates_kwargs (input_param .kwargs_type )
556+ for param_name , current_value in intermediates_kwargs .items ():
557+ param = getattr (block_state , param_name )
558+ if current_value is not param : # Using identity comparison to check if object was modified
559+ state .add_intermediate (param_name , param , input_param .kwargs_type )
524560
525561
526562def combine_inputs (* named_input_lists : List [Tuple [str , List [InputParam ]]]) -> List [InputParam ]:
@@ -550,16 +586,16 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
550586 input_param .default is not None and
551587 current_param .default != input_param .default ):
552588 warnings .warn (
553- f"Multiple different default values found for input '{ input_param . name } ': "
554- f"{ current_param .default } (from block '{ value_sources [input_param . name ]} ') and "
589+ f"Multiple different default values found for input '{ input_name } ': "
590+ f"{ current_param .default } (from block '{ value_sources [input_name ]} ') and "
555591 f"{ input_param .default } (from block '{ block_name } '). Using { current_param .default } ."
556592 )
557593 if current_param .default is None and input_param .default is not None :
558- combined_dict [input_param . name ] = input_param
559- value_sources [input_param . name ] = block_name
594+ combined_dict [input_name ] = input_param
595+ value_sources [input_name ] = block_name
560596 else :
561- combined_dict [input_param . name ] = input_param
562- value_sources [input_param . name ] = block_name
597+ combined_dict [input_name ] = input_param
598+ value_sources [input_name ] = block_name
563599
564600 return list (combined_dict .values ())
565601
@@ -661,7 +697,9 @@ def required_inputs(self) -> List[str]:
661697 required_by_all .intersection_update (block_required )
662698
663699 return list (required_by_all )
664-
700+
701+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
702+ # intermediate_inputs is by default required, unless you manually handle it inside the block
665703 @property
666704 def required_intermediates_inputs (self ) -> List [str ]:
667705 first_block = next (iter (self .blocks .values ()))
@@ -838,14 +876,21 @@ def __repr__(self):
838876 indented_desc += '\n ' + '\n ' .join (' ' + line for line in desc_lines [1 :])
839877 blocks_str += f" Description: { indented_desc } \n \n "
840878
841- return (
842- f"{ header } \n "
843- f"{ desc } \n \n "
844- f"{ components_str } \n \n "
845- f"{ configs_str } \n \n "
846- f"{ blocks_str } "
847- f")"
848- )
879+ # Build the representation with conditional sections
880+ result = f"{ header } \n { desc } "
881+
882+ # Only add components section if it has content
883+ if components_str .strip ():
884+ result += f"\n \n { components_str } "
885+
886+ # Only add configs section if it has content
887+ if configs_str .strip ():
888+ result += f"\n \n { configs_str } "
889+
890+ # Always add blocks section
891+ result += f"\n \n { blocks_str } )"
892+
893+ return result
849894
850895
851896 @property
@@ -867,13 +912,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
867912 block_classes = []
868913 block_names = []
869914
870- @property
871- def model_name (self ):
872- return next (iter (self .blocks .values ())).model_name
873915
874916 @property
875917 def description (self ):
876918 return ""
919+
920+ @property
921+ def model_name (self ):
922+ return next (iter (self .blocks .values ())).model_name
923+
877924
878925 @property
879926 def expected_components (self ):
@@ -929,6 +976,8 @@ def required_inputs(self) -> List[str]:
929976
930977 return list (required_by_any )
931978
979+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
980+ # intermediate_inputs is by default required, unless you manually handle it inside the block
932981 @property
933982 def required_intermediates_inputs (self ) -> List [str ]:
934983 required_intermediates_inputs = []
@@ -960,11 +1009,15 @@ def intermediates_inputs(self) -> List[str]:
9601009 def get_intermediates_inputs (self ):
9611010 inputs = []
9621011 outputs = set ()
1012+ added_inputs = set ()
9631013
9641014 # Go through all blocks in order
9651015 for block in self .blocks .values ():
9661016 # Add inputs that aren't in outputs yet
967- inputs .extend (input_name for input_name in block .intermediates_inputs if input_name .name not in outputs )
1017+ for inp in block .intermediates_inputs :
1018+ if inp .name not in outputs and inp .name not in added_inputs :
1019+ inputs .append (inp )
1020+ added_inputs .add (inp .name )
9681021
9691022 # Only add outputs if the block cannot be skipped
9701023 should_add_outputs = True
@@ -1176,14 +1229,21 @@ def __repr__(self):
11761229 indented_desc += '\n ' + '\n ' .join (' ' + line for line in desc_lines [1 :])
11771230 blocks_str += f" Description: { indented_desc } \n \n "
11781231
1179- return (
1180- f"{ header } \n "
1181- f"{ desc } \n \n "
1182- f"{ components_str } \n \n "
1183- f"{ configs_str } \n \n "
1184- f"{ blocks_str } "
1185- f")"
1186- )
1232+ # Build the representation with conditional sections
1233+ result = f"{ header } \n { desc } "
1234+
1235+ # Only add components section if it has content
1236+ if components_str .strip ():
1237+ result += f"\n \n { components_str } "
1238+
1239+ # Only add configs section if it has content
1240+ if configs_str .strip ():
1241+ result += f"\n \n { configs_str } "
1242+
1243+ # Always add blocks section
1244+ result += f"\n \n { blocks_str } )"
1245+
1246+ return result
11871247
11881248
11891249 @property
@@ -1348,7 +1408,8 @@ def required_inputs(self) -> List[str]:
13481408
13491409 return list (required_by_any )
13501410
1351- # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
1411+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
1412+ # intermediate_inputs is by default required, unless you manually handle it inside the block
13521413 @property
13531414 def required_intermediates_inputs (self ) -> List [str ]:
13541415 required_intermediates_inputs = []
@@ -1384,6 +1445,22 @@ def __init__(self):
13841445 for block_name , block_cls in zip (self .block_names , self .block_classes ):
13851446 blocks [block_name ] = block_cls ()
13861447 self .blocks = blocks
1448+
1449+ @classmethod
1450+ def from_blocks_dict (cls , blocks_dict : Dict [str , Any ]) -> "LoopSequentialPipelineBlocks" :
1451+ """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
1452+
1453+ Args:
1454+ blocks_dict: Dictionary mapping block names to block instances
1455+
1456+ Returns:
1457+ A new LoopSequentialPipelineBlocks instance
1458+ """
1459+ instance = cls ()
1460+ instance .block_classes = [block .__class__ for block in blocks_dict .values ()]
1461+ instance .block_names = list (blocks_dict .keys ())
1462+ instance .blocks = blocks_dict
1463+ return instance
13871464
13881465 def loop_step (self , components , state : PipelineState , ** kwargs ):
13891466
@@ -1455,6 +1532,100 @@ def add_block_state(self, state: PipelineState, block_state: BlockState):
14551532 param = getattr (block_state , output_param .name )
14561533 state .add_intermediate (output_param .name , param , output_param .kwargs_type )
14571534
1535+ for input_param in self .intermediates_inputs :
1536+ if input_param .name and hasattr (block_state , input_param .name ):
1537+ param = getattr (block_state , input_param .name )
1538+ # Only add if the value is different from what's in the state
1539+ current_value = state .get_intermediate (input_param .name )
1540+ if current_value is not param : # Using identity comparison to check if object was modified
1541+ state .add_intermediate (input_param .name , param , input_param .kwargs_type )
1542+ elif input_param .kwargs_type :
1543+ # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
1544+ # we need to first find out which inputs are and loop through them.
1545+ intermediates_kwargs = state .get_intermediates_kwargs (input_param .kwargs_type )
1546+ for param_name , current_value in intermediates_kwargs .items ():
1547+ if not hasattr (block_state , param_name ):
1548+ continue
1549+ param = getattr (block_state , param_name )
1550+ if current_value is not param : # Using identity comparison to check if object was modified
1551+ state .add_intermediate (param_name , param , input_param .kwargs_type )
1552+
1553+
1554+ @property
1555+ def doc (self ):
1556+ return make_doc_string (
1557+ self .inputs ,
1558+ self .intermediates_inputs ,
1559+ self .outputs ,
1560+ self .description ,
1561+ class_name = self .__class__ .__name__ ,
1562+ expected_components = self .expected_components ,
1563+ expected_configs = self .expected_configs
1564+ )
1565+
1566+ # modified from SequentialPipelineBlocks,
1567+ #(does not need trigger_inputs related part so removed them,
1568+ # do not need to support auto block for loop blocks)
1569+ def __repr__ (self ):
1570+ class_name = self .__class__ .__name__
1571+ base_class = self .__class__ .__bases__ [0 ].__name__
1572+ header = (
1573+ f"{ class_name } (\n Class: { base_class } \n "
1574+ if base_class and base_class != "object"
1575+ else f"{ class_name } (\n "
1576+ )
1577+
1578+ # Format description with proper indentation
1579+ desc_lines = self .description .split ('\n ' )
1580+ desc = []
1581+ # First line with "Description:" label
1582+ desc .append (f" Description: { desc_lines [0 ]} " )
1583+ # Subsequent lines with proper indentation
1584+ if len (desc_lines ) > 1 :
1585+ desc .extend (f" { line } " for line in desc_lines [1 :])
1586+ desc = '\n ' .join (desc ) + '\n '
1587+
1588+ # Components section - focus only on expected components
1589+ expected_components = getattr (self , "expected_components" , [])
1590+ components_str = format_components (expected_components , indent_level = 2 , add_empty_lines = False )
1591+
1592+ # Configs section - use format_configs with add_empty_lines=False
1593+ expected_configs = getattr (self , "expected_configs" , [])
1594+ configs_str = format_configs (expected_configs , indent_level = 2 , add_empty_lines = False )
1595+
1596+ # Blocks section - moved to the end with simplified format
1597+ blocks_str = " Blocks:\n "
1598+ for i , (name , block ) in enumerate (self .blocks .items ()):
1599+
1600+ # For SequentialPipelineBlocks, show execution order
1601+ blocks_str += f" [{ i } ] { name } ({ block .__class__ .__name__ } )\n "
1602+
1603+ # Add block description
1604+ desc_lines = block .description .split ('\n ' )
1605+ indented_desc = desc_lines [0 ]
1606+ if len (desc_lines ) > 1 :
1607+ indented_desc += '\n ' + '\n ' .join (' ' + line for line in desc_lines [1 :])
1608+ blocks_str += f" Description: { indented_desc } \n \n "
1609+
1610+ # Build the representation with conditional sections
1611+ result = f"{ header } \n { desc } "
1612+
1613+ # Only add components section if it has content
1614+ if components_str .strip ():
1615+ result += f"\n \n { components_str } "
1616+
1617+ # Only add configs section if it has content
1618+ if configs_str .strip ():
1619+ result += f"\n \n { configs_str } "
1620+
1621+ # Always add blocks section
1622+ result += f"\n \n { blocks_str } )"
1623+
1624+ return result
1625+
1626+
1627+
1628+
14581629# YiYi TODO:
14591630# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
14601631# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader
0 commit comments