22
33from __future__ import annotations
44
5+ import json
56from abc import ABC , abstractmethod
67from enum import Enum
78from inspect import signature
9+ import re
810from typing import (
911 TYPE_CHECKING ,
1012 AbstractSet ,
1113 Any ,
1214 Callable ,
1315 ClassVar ,
16+ Literal ,
1417 Mapping ,
1518 Optional ,
1619 Type ,
2023 get_type_hints ,
2124)
2225
23- from pydantic import BaseModel , Field
24- from pydantic .fields import Undefined
26+ from pydantic import BaseModel , Field , validator
27+ from pydantic .fields import Undefined , ModelField
2528from pydantic .typing import NoArgAnyCallable
2629
2730if TYPE_CHECKING :
@@ -141,9 +144,11 @@ class UIType(str, Enum):
141144 # endregion
142145
143146 # region Misc
144- FilePath = "FilePath"
145147 Enum = "enum"
146148 Scheduler = "Scheduler"
149+ WorkflowField = "WorkflowField"
150+ IsIntermediate = "IsIntermediate"
151+ MetadataField = "MetadataField"
147152 # endregion
148153
149154
@@ -365,12 +370,12 @@ def OutputField(
365370class UIConfigBase (BaseModel ):
366371 """
367372 Provides additional node configuration to the UI.
368- This is used internally by the @tags and @title decorator logic. You probably want to use those
369- decorators, though you may add this class to a node definition to specify the title and tags.
373+ This is used internally by the @invocation decorator logic. Do not use this directly.
370374 """
371375
372- tags : Optional [list [str ]] = Field (default_factory = None , description = "The tags to display in the UI" )
373- title : Optional [str ] = Field (default = None , description = "The display name of the node" )
376+ tags : Optional [list [str ]] = Field (default_factory = None , description = "The node's tags" )
377+ title : Optional [str ] = Field (default = None , description = "The node's display name" )
378+ category : Optional [str ] = Field (default = None , description = "The node's category" )
374379
375380
376381class InvocationContext :
@@ -383,10 +388,11 @@ def __init__(self, services: InvocationServices, graph_execution_state_id: str):
383388
384389
385390class BaseInvocationOutput (BaseModel ):
386- """Base class for all invocation outputs"""
391+ """
392+ Base class for all invocation outputs.
387393
388- # All outputs must include a type name like this:
389- # type: Literal['your_output_name'] # noqa f821
394+ All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
395+ """
390396
391397 @classmethod
392398 def get_all_subclasses_tuple (cls ):
@@ -422,12 +428,12 @@ def __init__(self, node_id: str, field_name: str):
422428
423429
424430class BaseInvocation (ABC , BaseModel ):
425- """A node to process inputs and produce outputs.
426- May use dependency injection in __init__ to receive providers.
427431 """
432+ A node to process inputs and produce outputs.
433+ May use dependency injection in __init__ to receive providers.
428434
429- # All invocations must include a type name like this:
430- # type: Literal['your_output_name'] # noqa f821
435+ All invocations must use the `@invocation` decorator to provide their unique type.
436+ """
431437
432438 @classmethod
433439 def get_all_subclasses (cls ):
@@ -466,6 +472,8 @@ def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
466472 schema ["title" ] = uiconfig .title
467473 if uiconfig and hasattr (uiconfig , "tags" ):
468474 schema ["tags" ] = uiconfig .tags
475+ if uiconfig and hasattr (uiconfig , "category" ):
476+ schema ["category" ] = uiconfig .category
469477 if "required" not in schema or not isinstance (schema ["required" ], list ):
470478 schema ["required" ] = list ()
471479 schema ["required" ].extend (["type" , "id" ])
@@ -505,37 +513,110 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
505513 raise MissingInputException (self .__fields__ ["type" ].default , field_name )
506514 return self .invoke (context )
507515
508- id : str = Field (description = "The id of this node. Must be unique among all nodes." )
516+ id : str = Field (
517+ description = "The id of this instance of an invocation. Must be unique among all instances of invocations."
518+ )
509519 is_intermediate : bool = InputField (
510- default = False , description = "Whether or not this node is an intermediate node." , input = Input .Direct
520+ default = False , description = "Whether or not this is an intermediate invocation." , ui_type = UIType .IsIntermediate
521+ )
522+ workflow : Optional [str ] = InputField (
523+ default = None ,
524+ description = "The workflow to save with the image" ,
525+ ui_type = UIType .WorkflowField ,
511526 )
527+
528+ @validator ("workflow" , pre = True )
529+ def validate_workflow_is_json (cls , v ):
530+ if v is None :
531+ return None
532+ try :
533+ json .loads (v )
534+ except json .decoder .JSONDecodeError :
535+ raise ValueError ("Workflow must be valid JSON" )
536+ return v
537+
512538 UIConfig : ClassVar [Type [UIConfigBase ]]
513539
514540
515- T = TypeVar ("T" , bound = BaseInvocation )
541+ GenericBaseInvocation = TypeVar ("GenericBaseInvocation" , bound = BaseInvocation )
542+
543+
544+ def invocation (
545+ invocation_type : str , title : Optional [str ] = None , tags : Optional [list [str ]] = None , category : Optional [str ] = None
546+ ) -> Callable [[Type [GenericBaseInvocation ]], Type [GenericBaseInvocation ]]:
547+ """
548+ Adds metadata to an invocation.
516549
550+ :param str invocation_type: The type of the invocation. Must be unique among all invocations.
551+ :param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
552+ :param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
553+ :param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
554+ """
517555
518- def title (title : str ) -> Callable [[Type [T ]], Type [T ]]:
519- """Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
556+ def wrapper (cls : Type [GenericBaseInvocation ]) -> Type [GenericBaseInvocation ]:
557+ # Validate invocation types on creation of invocation classes
558+ # TODO: ensure unique?
559+ if re .compile (r"^\S+$" ).match (invocation_type ) is None :
560+ raise ValueError (f'"invocation_type" must consist of non-whitespace characters, got "{ invocation_type } "' )
520561
521- def wrapper ( cls : Type [ T ]) -> Type [ T ]:
562+ # Add OpenAPI schema extras
522563 uiconf_name = cls .__qualname__ + ".UIConfig"
523564 if not hasattr (cls , "UIConfig" ) or cls .UIConfig .__qualname__ != uiconf_name :
524565 cls .UIConfig = type (uiconf_name , (UIConfigBase ,), dict ())
525- cls .UIConfig .title = title
566+ if title is not None :
567+ cls .UIConfig .title = title
568+ if tags is not None :
569+ cls .UIConfig .tags = tags
570+ if category is not None :
571+ cls .UIConfig .category = category
572+
573+ # Add the invocation type to the pydantic model of the invocation
574+ invocation_type_annotation = Literal [invocation_type ] # type: ignore
575+ invocation_type_field = ModelField .infer (
576+ name = "type" ,
577+ value = invocation_type ,
578+ annotation = invocation_type_annotation ,
579+ class_validators = None ,
580+ config = cls .__config__ ,
581+ )
582+ cls .__fields__ .update ({"type" : invocation_type_field })
583+ cls .__annotations__ .update ({"type" : invocation_type_annotation })
584+
526585 return cls
527586
528587 return wrapper
529588
530589
531- def tags (* tags : str ) -> Callable [[Type [T ]], Type [T ]]:
532- """Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
590+ GenericBaseInvocationOutput = TypeVar ("GenericBaseInvocationOutput" , bound = BaseInvocationOutput )
591+
592+
593+ def invocation_output (
594+ output_type : str ,
595+ ) -> Callable [[Type [GenericBaseInvocationOutput ]], Type [GenericBaseInvocationOutput ]]:
596+ """
597+ Adds metadata to an invocation output.
598+
599+ :param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
600+ """
601+
602+ def wrapper (cls : Type [GenericBaseInvocationOutput ]) -> Type [GenericBaseInvocationOutput ]:
603+ # Validate output types on creation of invocation output classes
604+ # TODO: ensure unique?
605+ if re .compile (r"^\S+$" ).match (output_type ) is None :
606+ raise ValueError (f'"output_type" must consist of non-whitespace characters, got "{ output_type } "' )
607+
608+ # Add the output type to the pydantic model of the invocation output
609+ output_type_annotation = Literal [output_type ] # type: ignore
610+ output_type_field = ModelField .infer (
611+ name = "type" ,
612+ value = output_type ,
613+ annotation = output_type_annotation ,
614+ class_validators = None ,
615+ config = cls .__config__ ,
616+ )
617+ cls .__fields__ .update ({"type" : output_type_field })
618+ cls .__annotations__ .update ({"type" : output_type_annotation })
533619
534- def wrapper (cls : Type [T ]) -> Type [T ]:
535- uiconf_name = cls .__qualname__ + ".UIConfig"
536- if not hasattr (cls , "UIConfig" ) or cls .UIConfig .__qualname__ != uiconf_name :
537- cls .UIConfig = type (uiconf_name , (UIConfigBase ,), dict ())
538- cls .UIConfig .tags = list (tags )
539620 return cls
540621
541622 return wrapper
0 commit comments