Skip to content

Commit d20f98f

Browse files
fix(nodes): deep copy graph inputs
The change to memory session storage brings a subtle behaviour change. Previously, we serialized and deserialized everything (e.g. field state, invocation outputs, etc) constantly. The meant we were effectively working with deep-copied objects at all time. We could mutate objects freely without worrying about other references to the object. With memory storage, objects are now passed around by reference, and we cannot handle them in the same way. This is problematic for nodes that mutate their own inputs. There are two ways this causes a problem: - An output is used as input for multiple nodes. If the first node mutates the output object while `invoke`ing, the next node will get the mutated object. - The invocation cache stores live python objects. When a node mutates an output pulled from the cache, the next node that uses the cached object will get the mutated object. The solution is to deep-copy a node's inputs as they are set, effectively reproducing the same behaviour as we had with the SQLite session storage. Nodes can safely mutate their inputs and those changes never leave the node's scope. Closes #5665
1 parent c9c150f commit d20f98f

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

invokeai/app/services/shared/graph.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import copy
44
import itertools
5-
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
5+
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
66

77
import networkx as nx
88
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
@@ -141,6 +141,16 @@ def are_connections_compatible(
141141
return are_connection_types_compatible(from_node_field, to_node_field)
142142

143143

144+
T = TypeVar("T")
145+
146+
147+
def copydeep(obj: T) -> T:
148+
"""Deep-copies an object. If it is a pydantic model, use the model's copy method."""
149+
if isinstance(obj, BaseModel):
150+
return obj.model_copy(deep=True)
151+
return copy.deepcopy(obj)
152+
153+
144154
class NodeAlreadyInGraphError(ValueError):
145155
pass
146156

@@ -1118,17 +1128,22 @@ def _get_next_node(self) -> Optional[BaseInvocation]:
11181128

11191129
def _prepare_inputs(self, node: BaseInvocation):
11201130
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
1131+
# Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
1132+
# will see the mutation.
11211133
if isinstance(node, CollectInvocation):
11221134
output_collection = [
1123-
getattr(self.results[edge.source.node_id], edge.source.field)
1135+
copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
11241136
for edge in input_edges
11251137
if edge.destination.field == "item"
11261138
]
11271139
node.collection = output_collection
11281140
else:
11291141
for edge in input_edges:
1130-
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
1131-
setattr(node, edge.destination.field, output_value)
1142+
setattr(
1143+
node,
1144+
edge.destination.field,
1145+
copydeep(getattr(self.results[edge.source.node_id], edge.source.field)),
1146+
)
11321147

11331148
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
11341149
def _is_edge_valid(self, edge: Edge) -> bool:

0 commit comments

Comments
 (0)