Skip to content

Commit 25f8ab2

Browse files
tests: fix test for breaking pydantic v2.12 change
Fixes a test failure introduced by pydantic/pydantic#11957 TL;DR: "after" model validators should be instance methods, not class methods. Batch model updated to use an instance method, which fixes the failing test.
1 parent c0469ef commit 25f8ab2

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import json
33
from itertools import chain, product
4-
from typing import Generator, Literal, Optional, TypeAlias, Union, cast
4+
from typing import Generator, Literal, Optional, TypeAlias, Union
55

66
from pydantic import (
77
AliasChoices,
@@ -15,7 +15,6 @@
1515
)
1616
from pydantic_core import to_jsonable_python
1717

18-
from invokeai.app.invocations.baseinvocation import BaseInvocation
1918
from invokeai.app.invocations.fields import ImageField
2019
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
2120
from invokeai.app.services.workflow_records.workflow_records_common import (
@@ -137,20 +136,18 @@ def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
137136
return v
138137

139138
@model_validator(mode="after")
140-
def validate_batch_nodes_and_edges(cls, values):
141-
batch_data_collection = cast(Optional[BatchDataCollection], values.data)
142-
if batch_data_collection is None:
143-
return values
144-
graph = cast(Graph, values.graph)
145-
for batch_data_list in batch_data_collection:
139+
def validate_batch_nodes_and_edges(self):
140+
if self.data is None:
141+
return self
142+
for batch_data_list in self.data:
146143
for batch_data in batch_data_list:
147144
try:
148-
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
145+
node = self.graph.get_node(batch_data.node_path)
149146
except NodeNotFoundError:
150147
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
151148
if batch_data.field_name not in type(node).model_fields:
152149
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
153-
return values
150+
return self
154151

155152
@field_validator("graph")
156153
def validate_graph(cls, v: Graph):

0 commit comments

Comments
 (0)