@@ -262,6 +262,26 @@ class InvocationRegistry:
262
262
@classmethod
263
263
def register_invocation (cls , invocation : type [BaseInvocation ]) -> None :
264
264
"""Registers an invocation."""
265
+
266
+ invocation_type = invocation .get_type ()
267
+ node_pack = invocation .UIConfig .node_pack
268
+
269
+ # Log a warning when an existing invocation is being clobbered by the one we are registering
270
+ clobbered_invocation = InvocationRegistry .get_invocation_for_type (invocation_type )
271
+ if clobbered_invocation is not None :
272
+ # This should always be true - we just checked if the invocation type was in the set
273
+ clobbered_node_pack = clobbered_invocation .UIConfig .node_pack
274
+
275
+ if clobbered_node_pack == "invokeai" :
276
+ # The invocation being clobbered is a core invocation
277
+ logger .warning (f'Overriding core node "{ invocation_type } " with node from "{ node_pack } "' )
278
+ else :
279
+ # The invocation being clobbered is a custom invocation
280
+ logger .warning (
281
+ f'Overriding node "{ invocation_type } " from "{ node_pack } " with node from "{ clobbered_node_pack } "'
282
+ )
283
+ cls ._invocation_classes .remove (clobbered_invocation )
284
+
265
285
cls ._invocation_classes .add (invocation )
266
286
cls .invalidate_invocation_typeadapter ()
267
287
@@ -320,6 +340,15 @@ def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] |
320
340
@classmethod
321
341
def register_output (cls , output : "type[TBaseInvocationOutput]" ) -> None :
322
342
"""Registers an invocation output."""
343
+ output_type = output .get_type ()
344
+
345
+ # Log a warning when an existing invocation is being clobbered by the one we are registering
346
+ clobbered_output = InvocationRegistry .get_output_for_type (output_type )
347
+ if clobbered_output is not None :
348
+ # TODO(psyche): We do not record the node pack of the output, so we cannot log it here
349
+ logger .warning (f'Overriding invocation output "{ output_type } "' )
350
+ cls ._output_classes .remove (clobbered_output )
351
+
323
352
cls ._output_classes .add (output )
324
353
cls .invalidate_output_typeadapter ()
325
354
@@ -328,6 +357,11 @@ def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
328
357
"""Gets all invocation outputs."""
329
358
return cls ._output_classes
330
359
360
+ @classmethod
361
+ def get_outputs_map (cls ) -> dict [str , type [BaseInvocationOutput ]]:
362
+ """Gets a map of all output types to their output classes."""
363
+ return {i .get_type (): i for i in cls .get_output_classes ()}
364
+
331
365
@classmethod
332
366
@lru_cache (maxsize = 1 )
333
367
def get_output_typeadapter (cls ) -> TypeAdapter [Any ]:
@@ -353,6 +387,11 @@ def get_output_types(cls) -> Iterable[str]:
353
387
"""Gets all invocation output types."""
354
388
return (i .get_type () for i in cls .get_output_classes ())
355
389
390
+ @classmethod
391
+ def get_output_for_type (cls , output_type : str ) -> type [BaseInvocationOutput ] | None :
392
+ """Gets the output class for a given output type."""
393
+ return cls .get_outputs_map ().get (output_type )
394
+
356
395
357
396
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
358
397
"id" ,
@@ -466,25 +505,6 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
466
505
# The node pack is the module name - will be "invokeai" for built-in nodes
467
506
node_pack = cls .__module__ .split ("." )[0 ]
468
507
469
- # Handle the case where an existing node is being clobbered by the one we are registering
470
- if invocation_type in InvocationRegistry .get_invocation_types ():
471
- clobbered_invocation = InvocationRegistry .get_invocation_for_type (invocation_type )
472
- # This should always be true - we just checked if the invocation type was in the set
473
- assert clobbered_invocation is not None
474
-
475
- clobbered_node_pack = clobbered_invocation .UIConfig .node_pack
476
-
477
- if clobbered_node_pack == "invokeai" :
478
- # The node being clobbered is a core node
479
- raise ValueError (
480
- f'Cannot load node "{ invocation_type } " from node pack "{ node_pack } " - a core node with the same type already exists'
481
- )
482
- else :
483
- # The node being clobbered is a custom node
484
- raise ValueError (
485
- f'Cannot load node "{ invocation_type } " from node pack "{ node_pack } " - a node with the same type already exists in node pack "{ clobbered_node_pack } "'
486
- )
487
-
488
508
validate_fields (cls .model_fields , invocation_type )
489
509
490
510
# Add OpenAPI schema extras
@@ -578,13 +598,9 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
578
598
if re .compile (r"^\S+$" ).match (output_type ) is None :
579
599
raise ValueError (f'"output_type" must consist of non-whitespace characters, got "{ output_type } "' )
580
600
581
- if output_type in InvocationRegistry .get_output_types ():
582
- raise ValueError (f'Invocation type "{ output_type } " already exists' )
583
-
584
601
validate_fields (cls .model_fields , output_type )
585
602
586
603
# Add the output type to the model.
587
-
588
604
output_type_annotation = Literal [output_type ] # type: ignore
589
605
output_type_field = Field (
590
606
title = "type" , default = output_type , json_schema_extra = {"field_kind" : FieldKind .NodeAttribute }
0 commit comments