@@ -264,6 +264,18 @@ def inputs(self) -> List[InputParam]:
264264 """List of input parameters. Must be implemented by subclasses."""
265265 return []
266266
267+ def _get_required_inputs (self ):
268+ input_names = []
269+ for input_param in self .inputs :
270+ if input_param .required :
271+ input_names .append (input_param .name )
272+
273+ return input_names
274+
275+ @property
276+ def required_inputs (self ) -> List [InputParam ]:
277+ return self ._get_required_inputs ()
278+
267279 @property
268280 def intermediate_outputs (self ) -> List [OutputParam ]:
269281 """List of intermediate output parameters. Must be implemented by subclasses."""
@@ -492,6 +504,17 @@ def intermediate_output_names(self) -> List[str]:
492504 def output_names (self ) -> List [str ]:
493505 return [output_param .name for output_param in self .outputs ]
494506
507+ @property
508+ def doc (self ):
509+ return make_doc_string (
510+ self .inputs ,
511+ self .outputs ,
512+ self .description ,
513+ class_name = self .__class__ .__name__ ,
514+ expected_components = self .expected_components ,
515+ expected_configs = self .expected_configs ,
516+ )
517+
495518
496519class AutoPipelineBlocks (ModularPipelineBlocks ):
497520 """
@@ -743,7 +766,6 @@ def __repr__(self):
743766 def doc (self ):
744767 return make_doc_string (
745768 self .inputs ,
746- self .intermediate_inputs ,
747769 self .outputs ,
748770 self .description ,
749771 class_name = self .__class__ .__name__ ,
@@ -2394,16 +2416,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
23942416
23952417 # Add inputs to state, using defaults if not provided in the kwargs or the state
23962418 # if same input already in the state, will override it if provided in the kwargs
2397- intermediate_inputs = [inp .name for inp in self .blocks .inputs ]
23982419 for expected_input_param in self .blocks .inputs :
23992420 name = expected_input_param .name
24002421 default = expected_input_param .default
24012422 kwargs_type = expected_input_param .kwargs_type
24022423 if name in passed_kwargs :
2403- if name not in intermediate_inputs :
2404- state .set (name , passed_kwargs .pop (name ), kwargs_type )
2405- else :
2406- state .set (name , passed_kwargs [name ], kwargs_type )
2424+ state .set (name , passed_kwargs .pop (name ), kwargs_type )
24072425 elif name not in state .values :
24082426 state .set (name , default , kwargs_type )
24092427
0 commit comments