|
1 | 1 | import datetime
|
2 | 2 | import json
|
3 | 3 | from itertools import chain, product
|
4 |
| -from typing import Generator, Literal, Optional, TypeAlias, Union, cast |
| 4 | +from typing import Generator, Literal, Optional, TypeAlias, Union |
5 | 5 |
|
6 | 6 | from pydantic import (
|
7 | 7 | AliasChoices,
|
|
15 | 15 | )
|
16 | 16 | from pydantic_core import to_jsonable_python
|
17 | 17 |
|
18 |
| -from invokeai.app.invocations.baseinvocation import BaseInvocation |
19 | 18 | from invokeai.app.invocations.fields import ImageField
|
20 | 19 | from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
21 | 20 | from invokeai.app.services.workflow_records.workflow_records_common import (
|
@@ -137,20 +136,18 @@ def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
|
137 | 136 | return v
|
138 | 137 |
|
139 | 138 | @model_validator(mode="after")
|
140 |
| - def validate_batch_nodes_and_edges(cls, values): |
141 |
| - batch_data_collection = cast(Optional[BatchDataCollection], values.data) |
142 |
| - if batch_data_collection is None: |
143 |
| - return values |
144 |
| - graph = cast(Graph, values.graph) |
145 |
| - for batch_data_list in batch_data_collection: |
| 139 | + def validate_batch_nodes_and_edges(self): |
| 140 | + if self.data is None: |
| 141 | + return self |
| 142 | + for batch_data_list in self.data: |
146 | 143 | for batch_data in batch_data_list:
|
147 | 144 | try:
|
148 |
| - node = cast(BaseInvocation, graph.get_node(batch_data.node_path)) |
| 145 | + node = self.graph.get_node(batch_data.node_path) |
149 | 146 | except NodeNotFoundError:
|
150 | 147 | raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
151 | 148 | if batch_data.field_name not in type(node).model_fields:
|
152 | 149 | raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
153 |
| - return values |
| 150 | + return self |
154 | 151 |
|
155 | 152 | @field_validator("graph")
|
156 | 153 | def validate_graph(cls, v: Graph):
|
|
0 commit comments