@@ -184,6 +184,23 @@ def __init__(self, **kwargs):
184184 for key , value in kwargs .items ():
185185 setattr (self , key , value )
186186
187+ def __getitem__ (self , key : str ):
188+ # allows block_state["foo"]
189+ return getattr (self , key , None )
190+
191+ def __setitem__ (self , key : str , value : Any ):
192+ # allows block_state["foo"] = "bar"
193+ setattr (self , key , value )
194+
195+ def as_dict (self ):
196+ """
197+ Convert BlockState to a dictionary.
198+
199+ Returns:
200+ Dict[str, Any]: Dictionary containing all attributes of the BlockState
201+ """
202+ return {key : value for key , value in self .__dict__ .items ()}
203+
187204 def __repr__ (self ):
188205 def format_value (v ):
189206 # Handle tensors directly
@@ -523,8 +540,12 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
523540
524541 for block_name , inputs in named_input_lists :
525542 for input_param in inputs :
526- if input_param .name in combined_dict :
527- current_param = combined_dict [input_param .name ]
543+ if input_param .name is None and input_param .kwargs_type is not None :
544+ input_name = "*_" + input_param .kwargs_type
545+ else :
546+ input_name = input_param .name
547+ if input_name in combined_dict :
548+ current_param = combined_dict [input_name ]
528549 if (current_param .default is not None and
529550 input_param .default is not None and
530551 current_param .default != input_param .default ):
@@ -557,7 +578,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
557578
558579 for block_name , outputs in named_output_lists :
559580 for output_param in outputs :
560- if output_param .name not in combined_dict :
581+ if ( output_param .name not in combined_dict ) or ( combined_dict [ output_param . name ]. kwargs_type is None and output_param . kwargs_type is not None ) :
561582 combined_dict [output_param .name ] = output_param
562583
563584 return list (combined_dict .values ())
@@ -919,6 +940,9 @@ def required_intermediates_inputs(self) -> List[str]:
919940 # YiYi TODO: add test for this
920941 @property
921942 def inputs (self ) -> List [Tuple [str , Any ]]:
943+ return self .get_inputs ()
944+
945+ def get_inputs (self ):
922946 named_inputs = [(name , block .inputs ) for name , block in self .blocks .items ()]
923947 combined_inputs = combine_inputs (* named_inputs )
924948 # mark Required inputs only if that input is required any of the blocks
@@ -931,6 +955,9 @@ def inputs(self) -> List[Tuple[str, Any]]:
931955
932956 @property
933957 def intermediates_inputs (self ) -> List [str ]:
958+ return self .get_intermediates_inputs ()
959+
960+ def get_intermediates_inputs (self ):
934961 inputs = []
935962 outputs = set ()
936963
@@ -1169,7 +1196,262 @@ def doc(self):
11691196 expected_configs = self .expected_configs
11701197 )
11711198
1199+ #YiYi TODO: __repr__
1200+ class LoopSequentialPipelineBlocks (ModularPipelineMixin ):
1201+ """
1202+ A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
1203+ """
1204+
1205+ model_name = None
1206+ block_classes = []
1207+ block_names = []
1208+
1209+ @property
1210+ def description (self ) -> str :
1211+ """Description of the block. Must be implemented by subclasses."""
1212+ raise NotImplementedError ("description method must be implemented in subclasses" )
1213+
1214+ @property
1215+ def loop_expected_components (self ) -> List [ComponentSpec ]:
1216+ return []
1217+
1218+ @property
1219+ def loop_expected_configs (self ) -> List [ConfigSpec ]:
1220+ return []
1221+
1222+ @property
1223+ def loop_inputs (self ) -> List [InputParam ]:
1224+ """List of input parameters. Must be implemented by subclasses."""
1225+ return []
1226+
1227+ @property
1228+ def loop_intermediates_inputs (self ) -> List [InputParam ]:
1229+ """List of intermediate input parameters. Must be implemented by subclasses."""
1230+ return []
1231+
1232+ @property
1233+ def loop_intermediates_outputs (self ) -> List [OutputParam ]:
1234+ """List of intermediate output parameters. Must be implemented by subclasses."""
1235+ return []
1236+
1237+
1238+ @property
1239+ def loop_required_inputs (self ) -> List [str ]:
1240+ input_names = []
1241+ for input_param in self .loop_inputs :
1242+ if input_param .required :
1243+ input_names .append (input_param .name )
1244+ return input_names
1245+
1246+ @property
1247+ def loop_required_intermediates_inputs (self ) -> List [str ]:
1248+ input_names = []
1249+ for input_param in self .loop_intermediates_inputs :
1250+ if input_param .required :
1251+ input_names .append (input_param .name )
1252+ return input_names
1253+
1254+ # modified from SequentialPipelineBlocks to include loop_expected_components
1255+ @property
1256+ def expected_components (self ):
1257+ expected_components = []
1258+ for block in self .blocks .values ():
1259+ for component in block .expected_components :
1260+ if component not in expected_components :
1261+ expected_components .append (component )
1262+ for component in self .loop_expected_components :
1263+ if component not in expected_components :
1264+ expected_components .append (component )
1265+ return expected_components
1266+
1267+ # modified from SequentialPipelineBlocks to include loop_expected_configs
1268+ @property
1269+ def expected_configs (self ):
1270+ expected_configs = []
1271+ for block in self .blocks .values ():
1272+ for config in block .expected_configs :
1273+ if config not in expected_configs :
1274+ expected_configs .append (config )
1275+ for config in self .loop_expected_configs :
1276+ if config not in expected_configs :
1277+ expected_configs .append (config )
1278+ return expected_configs
1279+
1280+ # modified from SequentialPipelineBlocks to include loop_inputs
1281+ def get_inputs (self ):
1282+ named_inputs = [(name , block .inputs ) for name , block in self .blocks .items ()]
1283+ named_inputs .append (("loop" , self .loop_inputs ))
1284+ combined_inputs = combine_inputs (* named_inputs )
1285+ # mark Required inputs only if that input is required any of the blocks
1286+ for input_param in combined_inputs :
1287+ if input_param .name in self .required_inputs :
1288+ input_param .required = True
1289+ else :
1290+ input_param .required = False
1291+ return combined_inputs
1292+
1293+ # Copied from SequentialPipelineBlocks
1294+ @property
1295+ def inputs (self ):
1296+ return self .get_inputs ()
1297+
1298+
1299+ # modified from SequentialPipelineBlocks to include loop_intermediates_inputs
1300+ @property
1301+ def intermediates_inputs (self ):
1302+ intermediates = self .get_intermediates_inputs ()
1303+ intermediate_names = [input .name for input in intermediates ]
1304+ for loop_intermediate_input in self .loop_intermediates_inputs :
1305+ if loop_intermediate_input .name not in intermediate_names :
1306+ intermediates .append (loop_intermediate_input )
1307+ return intermediates
1308+
1309+
1310+ # Copied from SequentialPipelineBlocks
1311+ def get_intermediates_inputs (self ):
1312+ inputs = []
1313+ outputs = set ()
1314+
1315+ # Go through all blocks in order
1316+ for block in self .blocks .values ():
1317+ # Add inputs that aren't in outputs yet
1318+ inputs .extend (input_name for input_name in block .intermediates_inputs if input_name .name not in outputs )
1319+
1320+ # Only add outputs if the block cannot be skipped
1321+ should_add_outputs = True
1322+ if hasattr (block , "block_trigger_inputs" ) and None not in block .block_trigger_inputs :
1323+ should_add_outputs = False
1324+
1325+ if should_add_outputs :
1326+ # Add this block's outputs
1327+ block_intermediates_outputs = [out .name for out in block .intermediates_outputs ]
1328+ outputs .update (block_intermediates_outputs )
1329+ return inputs
1330+
11721331
1332+ # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
1333+ @property
1334+ def required_inputs (self ) -> List [str ]:
1335+ # Get the first block from the dictionary
1336+ first_block = next (iter (self .blocks .values ()))
1337+ required_by_any = set (getattr (first_block , "required_inputs" , set ()))
1338+
1339+ required_by_loop = set (getattr (self , "loop_required_inputs" , set ()))
1340+ required_by_any .update (required_by_loop )
1341+
1342+ # Union with required inputs from all other blocks
1343+ for block in list (self .blocks .values ())[1 :]:
1344+ block_required = set (getattr (block , "required_inputs" , set ()))
1345+ required_by_any .update (block_required )
1346+
1347+ return list (required_by_any )
1348+
1349+ # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
1350+ @property
1351+ def required_intermediates_inputs (self ) -> List [str ]:
1352+ required_intermediates_inputs = []
1353+ for input_param in self .intermediates_inputs :
1354+ if input_param .required :
1355+ required_intermediates_inputs .append (input_param .name )
1356+ for input_param in self .loop_intermediates_inputs :
1357+ if input_param .required :
1358+ required_intermediates_inputs .append (input_param .name )
1359+ return required_intermediates_inputs
1360+
1361+
1362+ # YiYi TODO: this need to be thought about more
1363+ # modified from SequentialPipelineBlocks to include loop_intermediates_outputs
1364+ @property
1365+ def intermediates_outputs (self ) -> List [str ]:
1366+ named_outputs = [(name , block .intermediates_outputs ) for name , block in self .blocks .items ()]
1367+ combined_outputs = combine_outputs (* named_outputs )
1368+ for output in self .loop_intermediates_outputs :
1369+ if output .name not in set ([output .name for output in combined_outputs ]):
1370+ combined_outputs .append (output )
1371+ return combined_outputs
1372+
1373+ # YiYi TODO: this need to be thought about more
1374+ # copied from SequentialPipelineBlocks
1375+ @property
1376+ def outputs (self ) -> List [str ]:
1377+ return next (reversed (self .blocks .values ())).intermediates_outputs
1378+
1379+
1380+ def __init__ (self ):
1381+ blocks = OrderedDict ()
1382+ for block_name , block_cls in zip (self .block_names , self .block_classes ):
1383+ blocks [block_name ] = block_cls ()
1384+ self .blocks = blocks
1385+
1386+ def loop_step (self , components , state : PipelineState , ** kwargs ):
1387+
1388+ for block_name , block in self .blocks .items ():
1389+ try :
1390+ components , state = block (components , state , ** kwargs )
1391+ except Exception as e :
1392+ error_msg = (
1393+ f"\n Error in block: ({ block_name } , { block .__class__ .__name__ } )\n "
1394+ f"Error details: { str (e )} \n "
1395+ f"Traceback:\n { traceback .format_exc ()} "
1396+ )
1397+ logger .error (error_msg )
1398+ raise
1399+ return components , state
1400+
1401+ def __call__ (self , components , state : PipelineState ) -> PipelineState :
1402+ raise NotImplementedError ("`__call__` method needs to be implemented by the subclass" )
1403+
1404+
1405+ def get_block_state (self , state : PipelineState ) -> dict :
1406+ """Get all inputs and intermediates in one dictionary"""
1407+ data = {}
1408+
1409+ # Check inputs
1410+ for input_param in self .inputs :
1411+ if input_param .name :
1412+ value = state .get_input (input_param .name )
1413+ if input_param .required and value is None :
1414+ raise ValueError (f"Required input '{ input_param .name } ' is missing" )
1415+ elif value is not None or (value is None and input_param .name not in data ):
1416+ data [input_param .name ] = value
1417+ elif input_param .kwargs_type :
1418+ # if kwargs_type is provided, get all inputs with matching kwargs_type
1419+ if input_param .kwargs_type not in data :
1420+ data [input_param .kwargs_type ] = {}
1421+ inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type )
1422+ if inputs_kwargs :
1423+ for k , v in inputs_kwargs .items ():
1424+ if v is not None :
1425+ data [k ] = v
1426+ data [input_param .kwargs_type ][k ] = v
1427+
1428+ # Check intermediates
1429+ for input_param in self .intermediates_inputs :
1430+ if input_param .name :
1431+ value = state .get_intermediate (input_param .name )
1432+ if input_param .required and value is None :
1433+ raise ValueError (f"Required intermediate input '{ input_param .name } ' is missing" )
1434+ elif value is not None or (value is None and input_param .name not in data ):
1435+ data [input_param .name ] = value
1436+ elif input_param .kwargs_type :
1437+ # if kwargs_type is provided, get all intermediates with matching kwargs_type
1438+ if input_param .kwargs_type not in data :
1439+ data [input_param .kwargs_type ] = {}
1440+ intermediates_kwargs = state .get_intermediates_kwargs (input_param .kwargs_type )
1441+ if intermediates_kwargs :
1442+ for k , v in intermediates_kwargs .items ():
1443+ if v is not None :
1444+ if k not in data :
1445+ data [k ] = v
1446+ data [input_param .kwargs_type ][k ] = v
1447+ return BlockState (** data )
1448+
1449+ def add_block_state (self , state : PipelineState , block_state : BlockState ):
1450+ for output_param in self .intermediates_outputs :
1451+ if not hasattr (block_state , output_param .name ):
1452+ raise ValueError (f"Intermediate output '{ output_param .name } ' is missing in block state" )
1453+ param = getattr (block_state , output_param .name )
1454+ state .add_intermediate (output_param .name , param , output_param .kwargs_type )
11731455
11741456# YiYi TODO:
11751457# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
0 commit comments