|
45 | 45 | OutputParam, |
46 | 46 | format_components, |
47 | 47 | format_configs, |
48 | | - format_inputs_short, |
49 | | - format_intermediates_short, |
50 | 48 | make_doc_string, |
51 | 49 | ) |
52 | 50 |
|
@@ -142,12 +140,8 @@ def format_value(v): |
142 | 140 | values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items()) |
143 | 141 | kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items()) |
144 | 142 |
|
145 | | - return ( |
146 | | - f"PipelineState(\n" |
147 | | - f" values={{\n{values_str}\n }},\n" |
148 | | - f" kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n" |
149 | | - f")" |
150 | | - ) |
| 143 | + return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)" |
| 144 | + |
151 | 145 |
|
152 | 146 | @dataclass |
153 | 147 | class BlockState: |
@@ -402,20 +396,21 @@ def set_block_state(self, state: PipelineState, block_state: BlockState): |
402 | 396 | current_value = state.get(input_param.name) |
403 | 397 | if current_value is not param: # Using identity comparison to check if object was modified |
404 | 398 | state.set(input_param.name, param, input_param.kwargs_type) |
| 399 | + |
405 | 400 | elif input_param.kwargs_type: |
406 | | - import ipdb; ipdb.set_trace() |
407 | 401 | # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters |
408 | 402 | # we need to first find out which inputs are and loop through them. |
409 | 403 | intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type) |
410 | 404 | for param_name, current_value in intermediate_kwargs.items(): |
411 | | - try: |
412 | | - if not hasattr(block_state, param_name): |
413 | | - continue |
414 | | - param = getattr(block_state, param_name) |
415 | | - if current_value is not param: # Using identity comparison to check if object was modified |
416 | | - state.set(param_name, param, input_param.kwargs_type) |
417 | | - except: |
418 | | - import ipdb; ipdb.set_trace() |
| 405 | + if param_name is None: |
| 406 | + continue |
| 407 | + |
| 408 | + if not hasattr(block_state, param_name): |
| 409 | + continue |
| 410 | + |
| 411 | + param = getattr(block_state, param_name) |
| 412 | + if current_value is not param: # Using identity comparison to check if object was modified |
| 413 | + state.set(param_name, param, input_param.kwargs_type) |
419 | 414 |
|
420 | 415 | @staticmethod |
421 | 416 | def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: |
@@ -496,200 +491,6 @@ def output_names(self) -> List[str]: |
496 | 491 | return [output_param.name for output_param in self.outputs] |
497 | 492 |
|
498 | 493 |
|
499 | | -class PipelineBlock(ModularPipelineBlocks): |
500 | | - """ |
501 | | - A Pipeline Block is the basic building block of a Modular Pipeline. |
502 | | -
|
503 | | - This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the |
504 | | - library implements for all the pipeline blocks (such as loading or saving etc.) |
505 | | -
|
506 | | - <Tip warning={true}> |
507 | | -
|
508 | | - This is an experimental feature and is likely to change in the future. |
509 | | -
|
510 | | - </Tip> |
511 | | -
|
512 | | - Args: |
513 | | - description (str, optional): A description of the block, defaults to None. Define as a property in subclasses. |
514 | | - expected_components (List[ComponentSpec], optional): |
515 | | - A list of components that are expected to be used in the block, defaults to []. To override, define as a |
516 | | - property in subclasses. |
517 | | - expected_configs (List[ConfigSpec], optional): |
518 | | - A list of configs that are expected to be used in the block, defaults to []. To override, define as a |
519 | | - property in subclasses. |
520 | | - inputs (List[InputParam], optional): |
521 | | - A list of inputs that are expected to be used in the block, defaults to []. To override, define as a |
522 | | - property in subclasses. |
523 | | - intermediate_inputs (List[InputParam], optional): |
524 | | - A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, |
525 | | - define as a property in subclasses. |
526 | | - intermediate_outputs (List[OutputParam], optional): |
527 | | - A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, |
528 | | - define as a property in subclasses. |
529 | | - outputs (List[OutputParam], optional): |
530 | | - A list of outputs that are expected to be used in the block, defaults to []. To override, define as a |
531 | | - property in subclasses. |
532 | | - required_inputs (List[str], optional): |
533 | | - A list of required inputs that are expected to be used in the block, defaults to []. To override, define as |
534 | | - a property in subclasses. |
535 | | - required_intermediate_inputs (List[str], optional): |
536 | | - A list of required intermediate inputs that are expected to be used in the block, defaults to []. To |
537 | | - override, define as a property in subclasses. |
538 | | - required_intermediate_outputs (List[str], optional): |
539 | | - A list of required intermediate outputs that are expected to be used in the block, defaults to []. To |
540 | | - override, define as a property in subclasses. |
541 | | - """ |
542 | | - |
543 | | - model_name = None |
544 | | - |
545 | | - def __init__(self): |
546 | | - self.sub_blocks = InsertableDict() |
547 | | - |
548 | | - @property |
549 | | - def description(self) -> str: |
550 | | - """Description of the block. Must be implemented by subclasses.""" |
551 | | - # raise NotImplementedError("description method must be implemented in subclasses") |
552 | | - return "TODO: add a description" |
553 | | - |
554 | | - @property |
555 | | - def expected_components(self) -> List[ComponentSpec]: |
556 | | - return [] |
557 | | - |
558 | | - @property |
559 | | - def expected_configs(self) -> List[ConfigSpec]: |
560 | | - return [] |
561 | | - |
562 | | - @property |
563 | | - def inputs(self) -> List[InputParam]: |
564 | | - """List of input parameters. Must be implemented by subclasses.""" |
565 | | - return [] |
566 | | - |
567 | | - @property |
568 | | - def intermediate_inputs(self) -> List[InputParam]: |
569 | | - """List of intermediate input parameters. Must be implemented by subclasses.""" |
570 | | - return [] |
571 | | - |
572 | | - @property |
573 | | - def intermediate_outputs(self) -> List[OutputParam]: |
574 | | - """List of intermediate output parameters. Must be implemented by subclasses.""" |
575 | | - return [] |
576 | | - |
577 | | - def _get_outputs(self): |
578 | | - return self.intermediate_outputs |
579 | | - |
580 | | - # YiYi TODO: is it too easy for user to unintentionally override these properties? |
581 | | - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks |
582 | | - @property |
583 | | - def outputs(self) -> List[OutputParam]: |
584 | | - return self._get_outputs() |
585 | | - |
586 | | - def _get_required_inputs(self): |
587 | | - input_names = [] |
588 | | - for input_param in self.inputs: |
589 | | - if input_param.required: |
590 | | - input_names.append(input_param.name) |
591 | | - return input_names |
592 | | - |
593 | | - @property |
594 | | - def required_inputs(self) -> List[str]: |
595 | | - return self._get_required_inputs() |
596 | | - |
597 | | - def _get_required_intermediate_inputs(self): |
598 | | - input_names = [] |
599 | | - for input_param in self.intermediate_inputs: |
600 | | - if input_param.required: |
601 | | - input_names.append(input_param.name) |
602 | | - return input_names |
603 | | - |
604 | | - # YiYi TODO: maybe we do not need this, it is only used in docstring, |
605 | | - # intermediate_inputs is by default required, unless you manually handle it inside the block |
606 | | - @property |
607 | | - def required_intermediate_inputs(self) -> List[str]: |
608 | | - return self._get_required_intermediate_inputs() |
609 | | - |
610 | | - def __call__(self, pipeline, state: PipelineState) -> PipelineState: |
611 | | - raise NotImplementedError("__call__ method must be implemented in subclasses") |
612 | | - |
613 | | - def __repr__(self): |
614 | | - class_name = self.__class__.__name__ |
615 | | - base_class = self.__class__.__bases__[0].__name__ |
616 | | - |
617 | | - # Format description with proper indentation |
618 | | - desc_lines = self.description.split("\n") |
619 | | - desc = [] |
620 | | - # First line with "Description:" label |
621 | | - desc.append(f" Description: {desc_lines[0]}") |
622 | | - # Subsequent lines with proper indentation |
623 | | - if len(desc_lines) > 1: |
624 | | - desc.extend(f" {line}" for line in desc_lines[1:]) |
625 | | - desc = "\n".join(desc) + "\n" |
626 | | - |
627 | | - # Components section - use format_components with add_empty_lines=False |
628 | | - expected_components = getattr(self, "expected_components", []) |
629 | | - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) |
630 | | - components = " " + components_str.replace("\n", "\n ") |
631 | | - |
632 | | - # Configs section - use format_configs with add_empty_lines=False |
633 | | - expected_configs = getattr(self, "expected_configs", []) |
634 | | - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) |
635 | | - configs = " " + configs_str.replace("\n", "\n ") |
636 | | - |
637 | | - # Inputs section |
638 | | - inputs_str = format_inputs_short(self.inputs) |
639 | | - inputs = "Inputs:\n " + inputs_str |
640 | | - |
641 | | - # Intermediates section |
642 | | - intermediates_str = format_intermediates_short( |
643 | | - self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs |
644 | | - ) |
645 | | - intermediates = f"Intermediates:\n{intermediates_str}" |
646 | | - |
647 | | - return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)" |
648 | | - |
649 | | - @property |
650 | | - def doc(self): |
651 | | - return make_doc_string( |
652 | | - self.inputs, |
653 | | - self.intermediate_inputs, |
654 | | - self.outputs, |
655 | | - self.description, |
656 | | - class_name=self.__class__.__name__, |
657 | | - expected_components=self.expected_components, |
658 | | - expected_configs=self.expected_configs, |
659 | | - ) |
660 | | - |
661 | | - def set_block_state(self, state: PipelineState, block_state: BlockState): |
662 | | - for output_param in self.intermediate_outputs: |
663 | | - if not hasattr(block_state, output_param.name): |
664 | | - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") |
665 | | - param = getattr(block_state, output_param.name) |
666 | | - state.set(output_param.name, param, output_param.kwargs_type) |
667 | | - |
668 | | - for input_param in self.intermediate_inputs: |
669 | | - if hasattr(block_state, input_param.name): |
670 | | - param = getattr(block_state, input_param.name) |
671 | | - # Only add if the value is different from what's in the state |
672 | | - current_value = state.get(input_param.name) |
673 | | - if current_value is not param: # Using identity comparison to check if object was modified |
674 | | - state.set(input_param.name, param, input_param.kwargs_type) |
675 | | - |
676 | | - for input_param in self.intermediate_inputs: |
677 | | - if input_param.name and hasattr(block_state, input_param.name): |
678 | | - param = getattr(block_state, input_param.name) |
679 | | - # Only add if the value is different from what's in the state |
680 | | - current_value = state.get(input_param.name) |
681 | | - if current_value is not param: # Using identity comparison to check if object was modified |
682 | | - state.set(input_param.name, param, input_param.kwargs_type) |
683 | | - elif input_param.kwargs_type: |
684 | | - # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters |
685 | | - # we need to first find out which inputs are and loop through them. |
686 | | - intermediate_kwargs = state.get_kwargs(input_param.kwargs_type) |
687 | | - for param_name, current_value in intermediate_kwargs.items(): |
688 | | - param = getattr(block_state, param_name) |
689 | | - if current_value is not param: # Using identity comparison to check if object was modified |
690 | | - state.set(param_name, param, input_param.kwargs_type) |
691 | | - |
692 | | - |
693 | 494 | class AutoPipelineBlocks(ModularPipelineBlocks): |
694 | 495 | """ |
695 | 496 | A Pipeline Blocks that automatically selects a block to run based on the inputs. |
@@ -1042,7 +843,7 @@ def _get_inputs(self): |
1042 | 843 | if inp.name not in outputs and inp.name not in {input.name for input in inputs}: |
1043 | 844 | inputs.append(inp) |
1044 | 845 |
|
1045 | | - # Only add outputs if the block cannot be skipped |
| 846 | + # Only add outputs if the block cannot be skipped |
1046 | 847 | should_add_outputs = True |
1047 | 848 | if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: |
1048 | 849 | should_add_outputs = False |
|
0 commit comments