@@ -540,8 +540,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
540540
541541 def __init__ (self ):
542542 sub_blocks = InsertableDict ()
543- for block_name , block_cls in zip (self .block_names , self .block_classes ):
544- sub_blocks [block_name ] = block_cls ()
543+ for block_name , block in zip (self .block_names , self .block_classes ):
544+ if inspect .isclass (block ):
545+ sub_blocks [block_name ] = block ()
546+ else :
547+ sub_blocks [block_name ] = block
545548 self .sub_blocks = sub_blocks
546549 if not (len (self .block_classes ) == len (self .block_names ) == len (self .block_trigger_inputs )):
547550 raise ValueError (
@@ -848,8 +851,11 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo
848851
849852 def __init__ (self ):
850853 sub_blocks = InsertableDict ()
851- for block_name , block_cls in zip (self .block_names , self .block_classes ):
852- sub_blocks [block_name ] = block_cls ()
854+ for block_name , block in zip (self .block_names , self .block_classes ):
855+ if inspect .isclass (block ):
856+ sub_blocks [block_name ] = block ()
857+ else :
858+ sub_blocks [block_name ] = block
853859 self .sub_blocks = sub_blocks
854860
855861 def _get_inputs (self ):
@@ -1272,8 +1278,11 @@ def outputs(self) -> List[str]:
12721278
12731279 def __init__ (self ):
12741280 sub_blocks = InsertableDict ()
1275- for block_name , block_cls in zip (self .block_names , self .block_classes ):
1276- sub_blocks [block_name ] = block_cls ()
1281+ for block_name , block in zip (self .block_names , self .block_classes ):
1282+ if inspect .isclass (block ):
1283+ sub_blocks [block_name ] = block ()
1284+ else :
1285+ sub_blocks [block_name ] = block
12771286 self .sub_blocks = sub_blocks
12781287
12791288 @classmethod
0 commit comments