Skip to content

Commit 1f6c868

Browse files
feat(nodes): polymorphic fields (#4423)
## What type of PR is this? (check all applicable) - [ ] Refactor - [x] Feature - [x] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Description ### Polymorphic Fields Initial support for polymorphic field types. Polymorphic types are a single of or list of a specific type. For example, `Union[str, list[str]]`. Polymorphics do not yet have support for direct input in the UI (will come in the future). They will be forcibly set as Connection-only fields, in which case users will not be able to provide direct input to the field. If a polymorphic should present as a singleton type - which would allow direct input - the node must provide an explicit type hint. For example, `DenoiseLatents`' `CFG Scale` is polymorphic, but in the node editor, we want to present this as a number input. In the node definition, the field is given `ui_type=UIType.Float`, which tells the UI to treat this as a `float` field. The connection validation logic will prevent connecting a collection to `CFG Scale` in this situation, because it is typed as `float`. The workaround is to disable validation from the settings to make this specific connection. A future improvement will resolve this. ### Collection Fields This also introduces better support for collection field types. Like polymorphics, collection types are parsed automatically by the client and do not need any specific type hints. Also like polymorphics, there is no support yet for direct input of collection types in the UI. ### Other Changes - Disabling validation in workflow editor now displays the visual hints for valid connections, but lets you connect to anything. - Added `ui_order: int` to `InputField` and `OutputField`. The UI will use this, if present, to order fields in a node UI. See usage in `DenoiseLatents` for an example. - Updated the field colors - duplicate colors have just been lightened a bit. It's not perfect but it was a quick fix. - Field handles for collections are the same color as their single counterparts, but have a dark dot in the center of them. - Field handles for polymorphics are a rounded square with dot in the middle. - Removed all fields that just render `null` from `InputFieldRenderer`, replaced with a single fallback - Removed logic in `zValidatedWorkflow`, which checked for existence of node templates for each node in a workflow. This logic introduced a circular dependency, due to importing the global redux `store` in order to get the node templates within a zod schema. It's actually fine to just leave this out entirely; The case of a missing node template is handled by the UI. Fixing it otherwise would introduce a substantial headache. - Fixed the `ControlNetInvocation.control_model` field default, which was a string when it shouldn't have one. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> - Closes #4266 ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> Add this polymorphic float node to the end of your `invokeai/app/invocations/primitives.py`: ```py @invocation("float_poly", title="Float Poly Test", tags=["primitives", "float"], category="primitives") class FloatPolyInvocation(BaseInvocation): """A float polymorphic primitive value""" value: Union[float, list[float]] = InputField(default_factory=list, description="The float value") def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=self.value[0] if isinstance(self.value, list) else self.value) `` Head over to nodes and try to connecting up some collection and polymorphic inputs.
2 parents d69f3a0 + 59cb630 commit 1f6c868

File tree

23 files changed

+1470
-700
lines changed

23 files changed

+1470
-700
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,24 +105,39 @@ class UIType(str, Enum):
105105
"""
106106

107107
# region Primitives
108-
Integer = "integer"
109-
Float = "float"
110108
Boolean = "boolean"
111-
String = "string"
112-
Array = "array"
113-
Image = "ImageField"
114-
Latents = "LatentsField"
109+
Color = "ColorField"
115110
Conditioning = "ConditioningField"
116111
Control = "ControlField"
117-
Color = "ColorField"
118-
ImageCollection = "ImageCollection"
119-
ConditioningCollection = "ConditioningCollection"
112+
Float = "float"
113+
Image = "ImageField"
114+
Integer = "integer"
115+
Latents = "LatentsField"
116+
String = "string"
117+
# endregion
118+
119+
# region Collection Primitives
120+
BooleanCollection = "BooleanCollection"
120121
ColorCollection = "ColorCollection"
121-
LatentsCollection = "LatentsCollection"
122-
IntegerCollection = "IntegerCollection"
122+
ConditioningCollection = "ConditioningCollection"
123+
ControlCollection = "ControlCollection"
123124
FloatCollection = "FloatCollection"
125+
ImageCollection = "ImageCollection"
126+
IntegerCollection = "IntegerCollection"
127+
LatentsCollection = "LatentsCollection"
124128
StringCollection = "StringCollection"
125-
BooleanCollection = "BooleanCollection"
129+
# endregion
130+
131+
# region Polymorphic Primitives
132+
BooleanPolymorphic = "BooleanPolymorphic"
133+
ColorPolymorphic = "ColorPolymorphic"
134+
ConditioningPolymorphic = "ConditioningPolymorphic"
135+
ControlPolymorphic = "ControlPolymorphic"
136+
FloatPolymorphic = "FloatPolymorphic"
137+
ImagePolymorphic = "ImagePolymorphic"
138+
IntegerPolymorphic = "IntegerPolymorphic"
139+
LatentsPolymorphic = "LatentsPolymorphic"
140+
StringPolymorphic = "StringPolymorphic"
126141
# endregion
127142

128143
# region Models
@@ -176,6 +191,7 @@ class _InputField(BaseModel):
176191
ui_type: Optional[UIType]
177192
ui_component: Optional[UIComponent]
178193
ui_order: Optional[int]
194+
item_default: Optional[Any]
179195

180196

181197
class _OutputField(BaseModel):
@@ -223,6 +239,7 @@ def InputField(
223239
ui_component: Optional[UIComponent] = None,
224240
ui_hidden: bool = False,
225241
ui_order: Optional[int] = None,
242+
item_default: Optional[Any] = None,
226243
**kwargs: Any,
227244
) -> Any:
228245
"""
@@ -249,6 +266,11 @@ def InputField(
249266
For this case, you could provide `UIComponent.Textarea`.
250267
251268
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
269+
270+
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
271+
272+
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
273+
Ignored for non-collection fields..
252274
"""
253275
return Field(
254276
*args,
@@ -282,6 +304,7 @@ def InputField(
282304
ui_component=ui_component,
283305
ui_hidden=ui_hidden,
284306
ui_order=ui_order,
307+
item_default=item_default,
285308
**kwargs,
286309
)
287310

@@ -332,6 +355,8 @@ def OutputField(
332355
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
333356
334357
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
358+
359+
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
335360
"""
336361
return Field(
337362
*args,

invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ class ControlNetInvocation(BaseInvocation):
100100
"""Collects ControlNet info to pass to other nodes"""
101101

102102
image: ImageField = InputField(description="The control image")
103-
control_model: ControlNetModelField = InputField(
104-
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
105-
)
103+
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
106104
control_weight: Union[float, List[float]] = InputField(
107105
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
108106
)

invokeai/app/invocations/latent.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
208208
)
209209
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
210210
control: Union[ControlField, list[ControlField]] = InputField(
211-
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
211+
default=None,
212+
description=FieldDescriptions.control,
213+
input=Input.Connection,
214+
ui_order=5,
212215
)
213216
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
214217
denoise_mask: Optional[DenoiseMaskField] = InputField(
215-
default=None,
216-
description=FieldDescriptions.mask,
218+
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
217219
)
218220

219221
@validator("cfg_scale")
@@ -317,7 +319,7 @@ def prep_control_data(
317319
context: InvocationContext,
318320
# really only need model for dtype and device
319321
model: StableDiffusionGeneratorPipeline,
320-
control_input: List[ControlField],
322+
control_input: Union[ControlField, List[ControlField]],
321323
latents_shape: List[int],
322324
exit_stack: ExitStack,
323325
do_classifier_free_guidance: bool = True,

invokeai/app/invocations/primitives.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
InvocationContext,
1515
OutputField,
1616
UIComponent,
17-
UIType,
1817
invocation,
1918
invocation_output,
2019
)
@@ -40,7 +39,9 @@ class BooleanOutput(BaseInvocationOutput):
4039
class BooleanCollectionOutput(BaseInvocationOutput):
4140
"""Base class for nodes that output a collection of booleans"""
4241

43-
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
42+
collection: list[bool] = OutputField(
43+
description="The output boolean collection",
44+
)
4445

4546

4647
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
@@ -62,9 +63,7 @@ def invoke(self, context: InvocationContext) -> BooleanOutput:
6263
class BooleanCollectionInvocation(BaseInvocation):
6364
"""A collection of boolean primitive values"""
6465

65-
collection: list[bool] = InputField(
66-
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
67-
)
66+
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
6867

6968
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
7069
return BooleanCollectionOutput(collection=self.collection)
@@ -86,7 +85,9 @@ class IntegerOutput(BaseInvocationOutput):
8685
class IntegerCollectionOutput(BaseInvocationOutput):
8786
"""Base class for nodes that output a collection of integers"""
8887

89-
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
88+
collection: list[int] = OutputField(
89+
description="The int collection",
90+
)
9091

9192

9293
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
@@ -108,9 +109,7 @@ def invoke(self, context: InvocationContext) -> IntegerOutput:
108109
class IntegerCollectionInvocation(BaseInvocation):
109110
"""A collection of integer primitive values"""
110111

111-
collection: list[int] = InputField(
112-
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
113-
)
112+
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
114113

115114
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
116115
return IntegerCollectionOutput(collection=self.collection)
@@ -132,7 +131,9 @@ class FloatOutput(BaseInvocationOutput):
132131
class FloatCollectionOutput(BaseInvocationOutput):
133132
"""Base class for nodes that output a collection of floats"""
134133

135-
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
134+
collection: list[float] = OutputField(
135+
description="The float collection",
136+
)
136137

137138

138139
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
@@ -154,9 +155,7 @@ def invoke(self, context: InvocationContext) -> FloatOutput:
154155
class FloatCollectionInvocation(BaseInvocation):
155156
"""A collection of float primitive values"""
156157

157-
collection: list[float] = InputField(
158-
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
159-
)
158+
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
160159

161160
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
162161
return FloatCollectionOutput(collection=self.collection)
@@ -178,7 +177,9 @@ class StringOutput(BaseInvocationOutput):
178177
class StringCollectionOutput(BaseInvocationOutput):
179178
"""Base class for nodes that output a collection of strings"""
180179

181-
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
180+
collection: list[str] = OutputField(
181+
description="The output strings",
182+
)
182183

183184

184185
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
@@ -200,9 +201,7 @@ def invoke(self, context: InvocationContext) -> StringOutput:
200201
class StringCollectionInvocation(BaseInvocation):
201202
"""A collection of string primitive values"""
202203

203-
collection: list[str] = InputField(
204-
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
205-
)
204+
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
206205

207206
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
208207
return StringCollectionOutput(collection=self.collection)
@@ -232,7 +231,9 @@ class ImageOutput(BaseInvocationOutput):
232231
class ImageCollectionOutput(BaseInvocationOutput):
233232
"""Base class for nodes that output a collection of images"""
234233

235-
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
234+
collection: list[ImageField] = OutputField(
235+
description="The output images",
236+
)
236237

237238

238239
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
@@ -260,9 +261,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
260261
class ImageCollectionInvocation(BaseInvocation):
261262
"""A collection of image primitive values"""
262263

263-
collection: list[ImageField] = InputField(
264-
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
265-
)
264+
collection: list[ImageField] = InputField(description="The collection of image values")
266265

267266
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
268267
return ImageCollectionOutput(collection=self.collection)
@@ -316,7 +315,6 @@ class LatentsCollectionOutput(BaseInvocationOutput):
316315

317316
collection: list[LatentsField] = OutputField(
318317
description=FieldDescriptions.latents,
319-
ui_type=UIType.LatentsCollection,
320318
)
321319

322320

@@ -342,7 +340,7 @@ class LatentsCollectionInvocation(BaseInvocation):
342340
"""A collection of latents tensor primitive values"""
343341

344342
collection: list[LatentsField] = InputField(
345-
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
343+
description="The collection of latents tensors",
346344
)
347345

348346
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
@@ -385,7 +383,9 @@ class ColorOutput(BaseInvocationOutput):
385383
class ColorCollectionOutput(BaseInvocationOutput):
386384
"""Base class for nodes that output a collection of colors"""
387385

388-
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
386+
collection: list[ColorField] = OutputField(
387+
description="The output colors",
388+
)
389389

390390

391391
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
@@ -422,7 +422,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
422422

423423
collection: list[ConditioningField] = OutputField(
424424
description="The output conditioning tensors",
425-
ui_type=UIType.ConditioningCollection,
426425
)
427426

428427

@@ -453,7 +452,6 @@ class ConditioningCollectionInvocation(BaseInvocation):
453452
collection: list[ConditioningField] = InputField(
454453
default_factory=list,
455454
description="The collection of conditioning tensors",
456-
ui_type=UIType.ConditioningCollection,
457455
)
458456

459457
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:

invokeai/app/services/graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
112112
if to_type in get_args(from_type):
113113
return True
114114

115+
# allow int -> float, pydantic will cast for us
116+
if from_type is int and to_type is float:
117+
return True
118+
115119
# if not issubclass(from_type, to_type):
116120
if not is_union_subtype(from_type, to_type):
117121
return False

invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ const selector = createSelector(
6363
return;
6464
}
6565

66-
if (fieldTemplate.required && !field.value && !hasConnection) {
66+
if (
67+
fieldTemplate.required &&
68+
field.value === undefined &&
69+
!hasConnection
70+
) {
6771
reasons.push(
6872
`${node.data.label || nodeTemplate.title} -> ${
6973
field.label || fieldTemplate.title
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
export const colorTokenToCssVar = (colorToken: string) =>
2-
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
2+
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import { Tooltip } from '@chakra-ui/react';
22
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
33
import {
4+
COLLECTION_TYPES,
45
FIELDS,
56
HANDLE_TOOLTIP_OPEN_DELAY,
7+
MODEL_TYPES,
8+
POLYMORPHIC_TYPES,
69
} from 'features/nodes/types/constants';
710
import {
811
InputFieldTemplate,
@@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
1821
borderWidth: 0,
1922
zIndex: 1,
2023
};
24+
``;
2125

2226
export const inputHandleStyles: CSSProperties = {
2327
left: '-1rem',
@@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
4448
connectionError,
4549
} = props;
4650
const { name, type } = fieldTemplate;
47-
const { color, title } = FIELDS[type];
51+
const { color: typeColor, title } = FIELDS[type];
4852

4953
const styles: CSSProperties = useMemo(() => {
54+
const isCollectionType = COLLECTION_TYPES.includes(type);
55+
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
56+
const isModelType = MODEL_TYPES.includes(type);
57+
const color = colorTokenToCssVar(typeColor);
5058
const s: CSSProperties = {
51-
backgroundColor: colorTokenToCssVar(color),
59+
backgroundColor:
60+
isCollectionType || isPolymorphicType
61+
? 'var(--invokeai-colors-base-900)'
62+
: color,
5263
position: 'absolute',
5364
width: '1rem',
5465
height: '1rem',
55-
borderWidth: 0,
66+
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
67+
borderStyle: 'solid',
68+
borderColor: color,
69+
borderRadius: isModelType ? 4 : '100%',
5670
zIndex: 1,
5771
};
5872

@@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
7892

7993
return s;
8094
}, [
81-
color,
8295
connectionError,
8396
handleType,
8497
isConnectionInProgress,
8598
isConnectionStartField,
99+
type,
100+
typeColor,
86101
]);
87102

88103
const tooltip = useMemo(() => {

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
7575
sx={{
7676
display: 'flex',
7777
alignItems: 'center',
78+
h: 'full',
7879
mb: 0,
7980
px: 1,
8081
gap: 2,

0 commit comments

Comments
 (0)