Skip to content

Commit cbdafe7

Browse files
psychedeliciousmaryhipp
authored andcommitted
feat(nodes): allow node clobbering
1 parent 112cb76 commit cbdafe7

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,26 @@ class InvocationRegistry:
262262
@classmethod
263263
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
264264
"""Registers an invocation."""
265+
266+
invocation_type = invocation.get_type()
267+
node_pack = invocation.UIConfig.node_pack
268+
269+
# Log a warning when an existing invocation is being clobbered by the one we are registering
270+
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
271+
if clobbered_invocation is not None:
272+
# This should always be true - we just checked if the invocation type was in the set
273+
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
274+
275+
if clobbered_node_pack == "invokeai":
276+
# The invocation being clobbered is a core invocation
277+
logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
278+
else:
279+
# The invocation being clobbered is a custom invocation
280+
logger.warning(
281+
f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
282+
)
283+
cls._invocation_classes.remove(clobbered_invocation)
284+
265285
cls._invocation_classes.add(invocation)
266286
cls.invalidate_invocation_typeadapter()
267287

@@ -320,6 +340,15 @@ def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] |
320340
@classmethod
321341
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
322342
"""Registers an invocation output."""
343+
output_type = output.get_type()
344+
345+
# Log a warning when an existing invocation is being clobbered by the one we are registering
346+
clobbered_output = InvocationRegistry.get_output_for_type(output_type)
347+
if clobbered_output is not None:
348+
# TODO(psyche): We do not record the node pack of the output, so we cannot log it here
349+
logger.warning(f'Overriding invocation output "{output_type}"')
350+
cls._output_classes.remove(clobbered_output)
351+
323352
cls._output_classes.add(output)
324353
cls.invalidate_output_typeadapter()
325354

@@ -328,6 +357,11 @@ def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
328357
"""Gets all invocation outputs."""
329358
return cls._output_classes
330359

360+
@classmethod
361+
def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
362+
"""Gets a map of all output types to their output classes."""
363+
return {i.get_type(): i for i in cls.get_output_classes()}
364+
331365
@classmethod
332366
@lru_cache(maxsize=1)
333367
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
@@ -353,6 +387,11 @@ def get_output_types(cls) -> Iterable[str]:
353387
"""Gets all invocation output types."""
354388
return (i.get_type() for i in cls.get_output_classes())
355389

390+
@classmethod
391+
def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
392+
"""Gets the output class for a given output type."""
393+
return cls.get_outputs_map().get(output_type)
394+
356395

357396
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
358397
"id",
@@ -466,25 +505,6 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
466505
# The node pack is the module name - will be "invokeai" for built-in nodes
467506
node_pack = cls.__module__.split(".")[0]
468507

469-
# Handle the case where an existing node is being clobbered by the one we are registering
470-
if invocation_type in InvocationRegistry.get_invocation_types():
471-
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
472-
# This should always be true - we just checked if the invocation type was in the set
473-
assert clobbered_invocation is not None
474-
475-
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
476-
477-
if clobbered_node_pack == "invokeai":
478-
# The node being clobbered is a core node
479-
raise ValueError(
480-
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
481-
)
482-
else:
483-
# The node being clobbered is a custom node
484-
raise ValueError(
485-
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
486-
)
487-
488508
validate_fields(cls.model_fields, invocation_type)
489509

490510
# Add OpenAPI schema extras
@@ -578,13 +598,9 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
578598
if re.compile(r"^\S+$").match(output_type) is None:
579599
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
580600

581-
if output_type in InvocationRegistry.get_output_types():
582-
raise ValueError(f'Invocation type "{output_type}" already exists')
583-
584601
validate_fields(cls.model_fields, output_type)
585602

586603
# Add the output type to the model.
587-
588604
output_type_annotation = Literal[output_type] # type: ignore
589605
output_type_field = Field(
590606
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}

0 commit comments

Comments
 (0)