@@ -342,6 +342,16 @@ def expected_components(self) -> List[ComponentSpec]:
342342 def expected_configs (self ) -> List [ConfigSpec ]:
343343 return []
344344
345+ @property
346+ def intermediate_inputs (self ) -> List [OutputParam ]:
347+ """List of intermediate output parameters. Must be implemented by subclasses."""
348+ return []
349+
350+ @property
351+ def intermediate_outputs (self ) -> List [OutputParam ]:
352+ """List of intermediate output parameters. Must be implemented by subclasses."""
353+ return []
354+
345355 @classmethod
346356 def from_pretrained (
347357 cls ,
@@ -423,6 +433,60 @@ def init_pipeline(
423433 )
424434 return modular_pipeline
425435
436+ def get_block_state (self , state : PipelineState ) -> dict :
437+ """Get all inputs and intermediates in one dictionary"""
438+ data = {}
439+ state_inputs = self .inputs + self .intermediate_inputs
440+
441+ # Check inputs
442+ for input_param in state_inputs :
443+ if input_param .name :
444+ value = state .get_input (input_param .name ) or state .get_intermediate (input_param .name )
445+ if input_param .required and value is None :
446+ raise ValueError (f"Required input '{ input_param .name } ' is missing" )
447+ elif value is not None or (value is None and input_param .name not in data ):
448+ data [input_param .name ] = value
449+
450+ elif input_param .kwargs_type :
451+ # if kwargs_type is provided, get all inputs with matching kwargs_type
452+ if input_param .kwargs_type not in data :
453+ data [input_param .kwargs_type ] = {}
454+ inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type ) or state .get_intermediate_kwargs (
455+ input_param .kwargs_type
456+ )
457+ if inputs_kwargs :
458+ for k , v in inputs_kwargs .items ():
459+ if v is not None :
460+ data [k ] = v
461+ data [input_param .kwargs_type ][k ] = v
462+
463+ return BlockState (** data )
464+
465+ def set_block_state (self , state : PipelineState , block_state : BlockState ):
466+ for output_param in self .intermediate_outputs :
467+ if not hasattr (block_state , output_param .name ):
468+ raise ValueError (f"Intermediate output '{ output_param .name } ' is missing in block state" )
469+ param = getattr (block_state , output_param .name )
470+ state .set_intermediate (output_param .name , param , output_param .kwargs_type )
471+
472+ for input_param in self .intermediate_inputs :
473+ if input_param .name and hasattr (block_state , input_param .name ):
474+ param = getattr (block_state , input_param .name )
475+ # Only add if the value is different from what's in the state
476+ current_value = state .get_intermediate (input_param .name )
477+ if current_value is not param : # Using identity comparison to check if object was modified
478+ state .set_intermediate (input_param .name , param , input_param .kwargs_type )
479+ elif input_param .kwargs_type :
480+ # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
481+ # we need to first find out which inputs are and loop through them.
482+ intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
483+ for param_name , current_value in intermediate_kwargs .items ():
484+ if not hasattr (block_state , param_name ):
485+ continue
486+ param = getattr (block_state , param_name )
487+ if current_value is not param : # Using identity comparison to check if object was modified
488+ state .set_intermediate (param_name , param , input_param .kwargs_type )
489+
426490 @staticmethod
427491 def combine_inputs (* named_input_lists : List [Tuple [str , List [InputParam ]]]) -> List [InputParam ]:
428492 """
@@ -652,51 +716,6 @@ def doc(self):
652716 expected_configs = self .expected_configs ,
653717 )
654718
655- # YiYi TODO: input and inteermediate inputs with same name? should warn?
656- def get_block_state (self , state : PipelineState ) -> dict :
657- """Get all inputs and intermediates in one dictionary"""
658- data = {}
659-
660- # Check inputs
661- for input_param in self .inputs :
662- if input_param .name :
663- value = state .get_input (input_param .name )
664- if input_param .required and value is None :
665- raise ValueError (f"Required input '{ input_param .name } ' is missing" )
666- elif value is not None or (value is None and input_param .name not in data ):
667- data [input_param .name ] = value
668- elif input_param .kwargs_type :
669- # if kwargs_type is provided, get all inputs with matching kwargs_type
670- if input_param .kwargs_type not in data :
671- data [input_param .kwargs_type ] = {}
672- inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type )
673- if inputs_kwargs :
674- for k , v in inputs_kwargs .items ():
675- if v is not None :
676- data [k ] = v
677- data [input_param .kwargs_type ][k ] = v
678-
679- # Check intermediates
680- for input_param in self .intermediate_inputs :
681- if input_param .name :
682- value = state .get_intermediate (input_param .name )
683- if input_param .required and value is None :
684- raise ValueError (f"Required intermediate input '{ input_param .name } ' is missing" )
685- elif value is not None or (value is None and input_param .name not in data ):
686- data [input_param .name ] = value
687- elif input_param .kwargs_type :
688- # if kwargs_type is provided, get all intermediates with matching kwargs_type
689- if input_param .kwargs_type not in data :
690- data [input_param .kwargs_type ] = {}
691- intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
692- if intermediate_kwargs :
693- for k , v in intermediate_kwargs .items ():
694- if v is not None :
695- if k not in data :
696- data [k ] = v
697- data [input_param .kwargs_type ][k ] = v
698- return BlockState (** data )
699-
700719 def set_block_state (self , state : PipelineState , block_state : BlockState ):
701720 for output_param in self .intermediate_outputs :
702721 if not hasattr (block_state , output_param .name ):
@@ -1633,75 +1652,6 @@ def loop_step(self, components, state: PipelineState, **kwargs):
16331652 def __call__ (self , components , state : PipelineState ) -> PipelineState :
16341653 raise NotImplementedError ("`__call__` method needs to be implemented by the subclass" )
16351654
1636- def get_block_state (self , state : PipelineState ) -> dict :
1637- """Get all inputs and intermediates in one dictionary"""
1638- data = {}
1639-
1640- # Check inputs
1641- for input_param in self .inputs :
1642- if input_param .name :
1643- value = state .get_input (input_param .name )
1644- if input_param .required and value is None :
1645- raise ValueError (f"Required input '{ input_param .name } ' is missing" )
1646- elif value is not None or (value is None and input_param .name not in data ):
1647- data [input_param .name ] = value
1648- elif input_param .kwargs_type :
1649- # if kwargs_type is provided, get all inputs with matching kwargs_type
1650- if input_param .kwargs_type not in data :
1651- data [input_param .kwargs_type ] = {}
1652- inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type )
1653- if inputs_kwargs :
1654- for k , v in inputs_kwargs .items ():
1655- if v is not None :
1656- data [k ] = v
1657- data [input_param .kwargs_type ][k ] = v
1658-
1659- # Check intermediates
1660- for input_param in self .intermediate_inputs :
1661- if input_param .name :
1662- value = state .get_intermediate (input_param .name )
1663- if input_param .required and value is None :
1664- raise ValueError (f"Required intermediate input '{ input_param .name } ' is missing" )
1665- elif value is not None or (value is None and input_param .name not in data ):
1666- data [input_param .name ] = value
1667- elif input_param .kwargs_type :
1668- # if kwargs_type is provided, get all intermediates with matching kwargs_type
1669- if input_param .kwargs_type not in data :
1670- data [input_param .kwargs_type ] = {}
1671- intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
1672- if intermediate_kwargs :
1673- for k , v in intermediate_kwargs .items ():
1674- if v is not None :
1675- if k not in data :
1676- data [k ] = v
1677- data [input_param .kwargs_type ][k ] = v
1678- return BlockState (** data )
1679-
1680- def set_block_state (self , state : PipelineState , block_state : BlockState ):
1681- for output_param in self .intermediate_outputs :
1682- if not hasattr (block_state , output_param .name ):
1683- raise ValueError (f"Intermediate output '{ output_param .name } ' is missing in block state" )
1684- param = getattr (block_state , output_param .name )
1685- state .set_intermediate (output_param .name , param , output_param .kwargs_type )
1686-
1687- for input_param in self .intermediate_inputs :
1688- if input_param .name and hasattr (block_state , input_param .name ):
1689- param = getattr (block_state , input_param .name )
1690- # Only add if the value is different from what's in the state
1691- current_value = state .get_intermediate (input_param .name )
1692- if current_value is not param : # Using identity comparison to check if object was modified
1693- state .set_intermediate (input_param .name , param , input_param .kwargs_type )
1694- elif input_param .kwargs_type :
1695- # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
1696- # we need to first find out which inputs are and loop through them.
1697- intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
1698- for param_name , current_value in intermediate_kwargs .items ():
1699- if not hasattr (block_state , param_name ):
1700- continue
1701- param = getattr (block_state , param_name )
1702- if current_value is not param : # Using identity comparison to check if object was modified
1703- state .set_intermediate (param_name , param , input_param .kwargs_type )
1704-
17051655 @property
17061656 def doc (self ):
17071657 return make_doc_string (
@@ -1974,7 +1924,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
19741924
19751925 # Add inputs to state, using defaults if not provided in the kwargs or the state
19761926 # if same input already in the state, will override it if provided in the kwargs
1977-
19781927 intermediate_inputs = [inp .name for inp in self .blocks .intermediate_inputs ]
19791928 for expected_input_param in self .blocks .inputs :
19801929 name = expected_input_param .name
0 commit comments