Skip to content

Commit 2bd3cf2

Browse files
nodes phase 5: workflow saving and loading (#4353)
## What type of PR is this? (check all applicable) - [ ] Refactor - [x] Feature - [ ] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Description - Workflows are saved to image files directly - Image-outputting nodes have an `Embed Workflow` checkbox which, if enabled, saves the workflow - `BaseInvocation` now has an `workflow: Optional[str]` field, so all nodes automatically have the field (but again only image-outputting nodes display this in UI) - If this field is enabled, when the graph is created, the workflow is stringified and set in this field - Nodes should add `workflow=self.workflow` when they save their output image to have the workflow written to the image - Uploads now have their metadata retained so that you can upload somebody else's image and have access to that workflow - Graphs are no longer saved to images, workflows replace them ### TODO - Images created in the linear UI do not have a workflow saved yet. Need to write a function to build a workflow around the linear UI graph when using linear tabs. Unfortunately it will not have the nice positioning and size data the node editor gives you when you save a workflow... we'll have to figure out how to handle this. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> - Related Issue # - Closes # ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. -->
2 parents 4405c39 + 3cd2d3b commit 2bd3cf2

File tree

103 files changed

+3404
-2699
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+3404
-2699
lines changed

docs/contributing/INVOCATIONS.md

Lines changed: 121 additions & 518 deletions
Large diffs are not rendered by default.

invokeai/app/invocations/baseinvocation.py

Lines changed: 109 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
from __future__ import annotations
44

5+
import json
56
from abc import ABC, abstractmethod
67
from enum import Enum
78
from inspect import signature
9+
import re
810
from typing import (
911
TYPE_CHECKING,
1012
AbstractSet,
1113
Any,
1214
Callable,
1315
ClassVar,
16+
Literal,
1417
Mapping,
1518
Optional,
1619
Type,
@@ -20,8 +23,8 @@
2023
get_type_hints,
2124
)
2225

23-
from pydantic import BaseModel, Field
24-
from pydantic.fields import Undefined
26+
from pydantic import BaseModel, Field, validator
27+
from pydantic.fields import Undefined, ModelField
2528
from pydantic.typing import NoArgAnyCallable
2629

2730
if TYPE_CHECKING:
@@ -141,9 +144,11 @@ class UIType(str, Enum):
141144
# endregion
142145

143146
# region Misc
144-
FilePath = "FilePath"
145147
Enum = "enum"
146148
Scheduler = "Scheduler"
149+
WorkflowField = "WorkflowField"
150+
IsIntermediate = "IsIntermediate"
151+
MetadataField = "MetadataField"
147152
# endregion
148153

149154

@@ -365,12 +370,12 @@ def OutputField(
365370
class UIConfigBase(BaseModel):
366371
"""
367372
Provides additional node configuration to the UI.
368-
This is used internally by the @tags and @title decorator logic. You probably want to use those
369-
decorators, though you may add this class to a node definition to specify the title and tags.
373+
This is used internally by the @invocation decorator logic. Do not use this directly.
370374
"""
371375

372-
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
373-
title: Optional[str] = Field(default=None, description="The display name of the node")
376+
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
377+
title: Optional[str] = Field(default=None, description="The node's display name")
378+
category: Optional[str] = Field(default=None, description="The node's category")
374379

375380

376381
class InvocationContext:
@@ -383,10 +388,11 @@ def __init__(self, services: InvocationServices, graph_execution_state_id: str):
383388

384389

385390
class BaseInvocationOutput(BaseModel):
386-
"""Base class for all invocation outputs"""
391+
"""
392+
Base class for all invocation outputs.
387393
388-
# All outputs must include a type name like this:
389-
# type: Literal['your_output_name'] # noqa f821
394+
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
395+
"""
390396

391397
@classmethod
392398
def get_all_subclasses_tuple(cls):
@@ -422,12 +428,12 @@ def __init__(self, node_id: str, field_name: str):
422428

423429

424430
class BaseInvocation(ABC, BaseModel):
425-
"""A node to process inputs and produce outputs.
426-
May use dependency injection in __init__ to receive providers.
427431
"""
432+
A node to process inputs and produce outputs.
433+
May use dependency injection in __init__ to receive providers.
428434
429-
# All invocations must include a type name like this:
430-
# type: Literal['your_output_name'] # noqa f821
435+
All invocations must use the `@invocation` decorator to provide their unique type.
436+
"""
431437

432438
@classmethod
433439
def get_all_subclasses(cls):
@@ -466,6 +472,8 @@ def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
466472
schema["title"] = uiconfig.title
467473
if uiconfig and hasattr(uiconfig, "tags"):
468474
schema["tags"] = uiconfig.tags
475+
if uiconfig and hasattr(uiconfig, "category"):
476+
schema["category"] = uiconfig.category
469477
if "required" not in schema or not isinstance(schema["required"], list):
470478
schema["required"] = list()
471479
schema["required"].extend(["type", "id"])
@@ -505,37 +513,110 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
505513
raise MissingInputException(self.__fields__["type"].default, field_name)
506514
return self.invoke(context)
507515

508-
id: str = Field(description="The id of this node. Must be unique among all nodes.")
516+
id: str = Field(
517+
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
518+
)
509519
is_intermediate: bool = InputField(
510-
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
520+
default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
521+
)
522+
workflow: Optional[str] = InputField(
523+
default=None,
524+
description="The workflow to save with the image",
525+
ui_type=UIType.WorkflowField,
511526
)
527+
528+
@validator("workflow", pre=True)
529+
def validate_workflow_is_json(cls, v):
530+
if v is None:
531+
return None
532+
try:
533+
json.loads(v)
534+
except json.decoder.JSONDecodeError:
535+
raise ValueError("Workflow must be valid JSON")
536+
return v
537+
512538
UIConfig: ClassVar[Type[UIConfigBase]]
513539

514540

515-
T = TypeVar("T", bound=BaseInvocation)
541+
GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
542+
543+
544+
def invocation(
545+
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
546+
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
547+
"""
548+
Adds metadata to an invocation.
516549
550+
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
551+
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
552+
:param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
553+
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
554+
"""
517555

518-
def title(title: str) -> Callable[[Type[T]], Type[T]]:
519-
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
556+
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
557+
# Validate invocation types on creation of invocation classes
558+
# TODO: ensure unique?
559+
if re.compile(r"^\S+$").match(invocation_type) is None:
560+
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
520561

521-
def wrapper(cls: Type[T]) -> Type[T]:
562+
# Add OpenAPI schema extras
522563
uiconf_name = cls.__qualname__ + ".UIConfig"
523564
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
524565
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
525-
cls.UIConfig.title = title
566+
if title is not None:
567+
cls.UIConfig.title = title
568+
if tags is not None:
569+
cls.UIConfig.tags = tags
570+
if category is not None:
571+
cls.UIConfig.category = category
572+
573+
# Add the invocation type to the pydantic model of the invocation
574+
invocation_type_annotation = Literal[invocation_type] # type: ignore
575+
invocation_type_field = ModelField.infer(
576+
name="type",
577+
value=invocation_type,
578+
annotation=invocation_type_annotation,
579+
class_validators=None,
580+
config=cls.__config__,
581+
)
582+
cls.__fields__.update({"type": invocation_type_field})
583+
cls.__annotations__.update({"type": invocation_type_annotation})
584+
526585
return cls
527586

528587
return wrapper
529588

530589

531-
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
532-
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
590+
GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
591+
592+
593+
def invocation_output(
594+
output_type: str,
595+
) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
596+
"""
597+
Adds metadata to an invocation output.
598+
599+
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
600+
"""
601+
602+
def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
603+
# Validate output types on creation of invocation output classes
604+
# TODO: ensure unique?
605+
if re.compile(r"^\S+$").match(output_type) is None:
606+
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
607+
608+
# Add the output type to the pydantic model of the invocation output
609+
output_type_annotation = Literal[output_type] # type: ignore
610+
output_type_field = ModelField.infer(
611+
name="type",
612+
value=output_type,
613+
annotation=output_type_annotation,
614+
class_validators=None,
615+
config=cls.__config__,
616+
)
617+
cls.__fields__.update({"type": output_type_field})
618+
cls.__annotations__.update({"type": output_type_annotation})
533619

534-
def wrapper(cls: Type[T]) -> Type[T]:
535-
uiconf_name = cls.__qualname__ + ".UIConfig"
536-
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
537-
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
538-
cls.UIConfig.tags = list(tags)
539620
return cls
540621

541622
return wrapper

invokeai/app/invocations/collections.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
11
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
22

3-
from typing import Literal
43

54
import numpy as np
65
from pydantic import validator
76

87
from invokeai.app.invocations.primitives import IntegerCollectionOutput
98
from invokeai.app.util.misc import SEED_MAX, get_random_seed
109

11-
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
10+
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
1211

1312

14-
@title("Integer Range")
15-
@tags("collection", "integer", "range")
13+
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
1614
class RangeInvocation(BaseInvocation):
1715
"""Creates a range of numbers from start to stop with step"""
1816

19-
type: Literal["range"] = "range"
20-
21-
# Inputs
2217
start: int = InputField(default=0, description="The start of the range")
2318
stop: int = InputField(default=10, description="The stop of the range")
2419
step: int = InputField(default=1, description="The step of the range")
@@ -33,14 +28,15 @@ def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
3328
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
3429

3530

36-
@title("Integer Range of Size")
37-
@tags("range", "integer", "size", "collection")
31+
@invocation(
32+
"range_of_size",
33+
title="Integer Range of Size",
34+
tags=["collection", "integer", "size", "range"],
35+
category="collections",
36+
)
3837
class RangeOfSizeInvocation(BaseInvocation):
3938
"""Creates a range from start to start + size with step"""
4039

41-
type: Literal["range_of_size"] = "range_of_size"
42-
43-
# Inputs
4440
start: int = InputField(default=0, description="The start of the range")
4541
size: int = InputField(default=1, description="The number of values")
4642
step: int = InputField(default=1, description="The step of the range")
@@ -49,14 +45,15 @@ def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
4945
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
5046

5147

52-
@title("Random Range")
53-
@tags("range", "integer", "random", "collection")
48+
@invocation(
49+
"random_range",
50+
title="Random Range",
51+
tags=["range", "integer", "random", "collection"],
52+
category="collections",
53+
)
5454
class RandomRangeInvocation(BaseInvocation):
5555
"""Creates a collection of random numbers"""
5656

57-
type: Literal["random_range"] = "random_range"
58-
59-
# Inputs
6057
low: int = InputField(default=0, description="The inclusive low value")
6158
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
6259
size: int = InputField(default=1, description="The number of values to generate")

invokeai/app/invocations/compel.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from dataclasses import dataclass
3-
from typing import List, Literal, Union
3+
from typing import List, Union
44

55
import torch
66
from compel import Compel, ReturnedEmbeddingsType
@@ -26,8 +26,8 @@
2626
InvocationContext,
2727
OutputField,
2828
UIComponent,
29-
tags,
30-
title,
29+
invocation,
30+
invocation_output,
3131
)
3232
from .model import ClipField
3333

@@ -44,13 +44,10 @@ class ConditioningFieldData:
4444
# PerpNeg = "perp_neg"
4545

4646

47-
@title("Compel Prompt")
48-
@tags("prompt", "compel")
47+
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
4948
class CompelInvocation(BaseInvocation):
5049
"""Parse prompt using compel package to conditioning."""
5150

52-
type: Literal["compel"] = "compel"
53-
5451
prompt: str = InputField(
5552
default="",
5653
description=FieldDescriptions.compel_prompt,
@@ -265,13 +262,15 @@ def _lora_loader():
265262
return c, c_pooled, ec
266263

267264

268-
@title("SDXL Compel Prompt")
269-
@tags("sdxl", "compel", "prompt")
265+
@invocation(
266+
"sdxl_compel_prompt",
267+
title="SDXL Prompt",
268+
tags=["sdxl", "compel", "prompt"],
269+
category="conditioning",
270+
)
270271
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
271272
"""Parse prompt using compel package to conditioning."""
272273

273-
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
274-
275274
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
276275
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
277276
original_width: int = InputField(default=1024, description="")
@@ -324,13 +323,15 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
324323
)
325324

326325

327-
@title("SDXL Refiner Compel Prompt")
328-
@tags("sdxl", "compel", "prompt")
326+
@invocation(
327+
"sdxl_refiner_compel_prompt",
328+
title="SDXL Refiner Prompt",
329+
tags=["sdxl", "compel", "prompt"],
330+
category="conditioning",
331+
)
329332
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
330333
"""Parse prompt using compel package to conditioning."""
331334

332-
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
333-
334335
style: str = InputField(
335336
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
336337
) # TODO: ?
@@ -372,20 +373,17 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
372373
)
373374

374375

376+
@invocation_output("clip_skip_output")
375377
class ClipSkipInvocationOutput(BaseInvocationOutput):
376378
"""Clip skip node output"""
377379

378-
type: Literal["clip_skip_output"] = "clip_skip_output"
379380
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
380381

381382

382-
@title("CLIP Skip")
383-
@tags("clipskip", "clip", "skip")
383+
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
384384
class ClipSkipInvocation(BaseInvocation):
385385
"""Skip layers in clip text_encoder model."""
386386

387-
type: Literal["clip_skip"] = "clip_skip"
388-
389387
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
390388
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
391389

0 commit comments

Comments
 (0)