Skip to content

Commit 3d41467

Browse files
laipz8200Yeuoly
andauthored
fix(graph_engine): Cannot run single iteration or loop node (#31470)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
1 parent d76ad15 commit 3d41467

File tree

14 files changed

+286
-54
lines changed

14 files changed

+286
-54
lines changed

api/core/app/apps/workflow_app_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _prepare_single_node_execution(
157157
# Create initial runtime state with variable pool containing environment variables
158158
graph_runtime_state = GraphRuntimeState(
159159
variable_pool=VariablePool(
160-
system_variables=SystemVariable.empty(),
160+
system_variables=SystemVariable.default(),
161161
user_inputs={},
162162
environment_variables=workflow.environment_variables,
163163
),
@@ -272,7 +272,9 @@ def _get_graph_and_variable_pool_for_single_node_run(
272272
)
273273

274274
# init graph
275-
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
275+
graph = Graph.init(
276+
graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
277+
)
276278

277279
if not graph:
278280
raise ValueError("graph not found in workflow")

api/core/workflow/graph/graph.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def init(
288288
graph_config: Mapping[str, object],
289289
node_factory: NodeFactory,
290290
root_node_id: str | None = None,
291+
skip_validation: bool = False,
291292
) -> Graph:
292293
"""
293294
Initialize graph
@@ -339,8 +340,9 @@ def init(
339340
root_node=root_node,
340341
)
341342

342-
# Validate the graph structure using built-in validators
343-
get_graph_validator().validate(graph)
343+
if not skip_validation:
344+
# Validate the graph structure using built-in validators
345+
get_graph_validator().validate(graph)
344346

345347
return graph
346348

api/core/workflow/runtime/variable_pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class VariablePool(BaseModel):
4444
)
4545
system_variables: SystemVariable = Field(
4646
description="System variables",
47-
default_factory=SystemVariable.empty,
47+
default_factory=SystemVariable.default,
4848
)
4949
environment_variables: Sequence[Variable] = Field(
5050
description="Environment variables.",
@@ -271,4 +271,4 @@ def _add_system_variables(self, system_variable: SystemVariable):
271271
@classmethod
272272
def empty(cls) -> VariablePool:
273273
"""Create an empty variable pool."""
274-
return cls(system_variables=SystemVariable.empty())
274+
return cls(system_variables=SystemVariable.default())

api/core/workflow/system_variable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Mapping, Sequence
44
from types import MappingProxyType
55
from typing import Any
6+
from uuid import uuid4
67

78
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
89

@@ -72,8 +73,8 @@ def validate_json_fields(cls, data):
7273
return data
7374

7475
@classmethod
75-
def empty(cls) -> SystemVariable:
76-
return cls()
76+
def default(cls) -> SystemVariable:
77+
return cls(workflow_execution_id=str(uuid4()))
7778

7879
def to_dict(self) -> dict[SystemVariableKey, Any]:
7980
# NOTE: This method is provided for compatibility with legacy code.

api/core/workflow/workflow_entry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def run_free_node(
277277

278278
# init variable pool
279279
variable_pool = VariablePool(
280-
system_variables=SystemVariable.empty(),
280+
system_variables=SystemVariable.default(),
281281
user_inputs={},
282282
environment_variables=[],
283283
)

api/services/rag_pipeline/rag_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def run_draft_workflow_node(
436436
user_inputs=user_inputs,
437437
user_id=account.id,
438438
variable_pool=VariablePool(
439-
system_variables=SystemVariable.empty(),
439+
system_variables=SystemVariable.default(),
440440
user_inputs=user_inputs,
441441
environment_variables=[],
442442
conversation_variables=[],

api/services/workflow_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ def run_draft_workflow_node(
675675

676676
else:
677677
variable_pool = VariablePool(
678-
system_variables=SystemVariable.empty(),
678+
system_variables=SystemVariable.default(),
679679
user_inputs=user_inputs,
680680
environment_variables=draft_workflow.environment_variables,
681681
conversation_variables=[],
@@ -1063,7 +1063,7 @@ def _setup_variable_pool(
10631063
system_variable.conversation_id = conversation_id
10641064
system_variable.dialogue_count = 1
10651065
else:
1066-
system_variable = SystemVariable.empty()
1066+
system_variable = SystemVariable.default()
10671067

10681068
# init variable pool
10691069
variable_pool = VariablePool(
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
from core.app.apps.base_app_queue_manager import AppQueueManager
9+
from core.app.apps.workflow.app_runner import WorkflowAppRunner
10+
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
11+
from core.workflow.runtime import GraphRuntimeState, VariablePool
12+
from core.workflow.system_variable import SystemVariable
13+
from models.workflow import Workflow
14+
15+
16+
def _make_graph_state():
17+
variable_pool = VariablePool(
18+
system_variables=SystemVariable.default(),
19+
user_inputs={},
20+
environment_variables=[],
21+
conversation_variables=[],
22+
)
23+
return MagicMock(), variable_pool, GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
24+
25+
26+
@pytest.mark.parametrize(
27+
("single_iteration_run", "single_loop_run"),
28+
[
29+
(WorkflowAppGenerateEntity.SingleIterationRunEntity(node_id="iter", inputs={}), None),
30+
(None, WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id="loop", inputs={})),
31+
],
32+
)
33+
def test_run_uses_single_node_execution_branch(
34+
single_iteration_run: Any,
35+
single_loop_run: Any,
36+
) -> None:
37+
app_config = MagicMock()
38+
app_config.app_id = "app"
39+
app_config.tenant_id = "tenant"
40+
app_config.workflow_id = "workflow"
41+
42+
app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity)
43+
app_generate_entity.app_config = app_config
44+
app_generate_entity.inputs = {}
45+
app_generate_entity.files = []
46+
app_generate_entity.user_id = "user"
47+
app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
48+
app_generate_entity.workflow_execution_id = "execution-id"
49+
app_generate_entity.task_id = "task-id"
50+
app_generate_entity.call_depth = 0
51+
app_generate_entity.trace_manager = None
52+
app_generate_entity.single_iteration_run = single_iteration_run
53+
app_generate_entity.single_loop_run = single_loop_run
54+
55+
workflow = MagicMock(spec=Workflow)
56+
workflow.tenant_id = "tenant"
57+
workflow.app_id = "app"
58+
workflow.id = "workflow"
59+
workflow.type = "workflow"
60+
workflow.version = "v1"
61+
workflow.graph_dict = {"nodes": [], "edges": []}
62+
workflow.environment_variables = []
63+
64+
runner = WorkflowAppRunner(
65+
application_generate_entity=app_generate_entity,
66+
queue_manager=MagicMock(spec=AppQueueManager),
67+
variable_loader=MagicMock(),
68+
workflow=workflow,
69+
system_user_id="system-user",
70+
workflow_execution_repository=MagicMock(),
71+
workflow_node_execution_repository=MagicMock(),
72+
)
73+
74+
graph, variable_pool, graph_runtime_state = _make_graph_state()
75+
mock_workflow_entry = MagicMock()
76+
mock_workflow_entry.graph_engine = MagicMock()
77+
mock_workflow_entry.graph_engine.layer = MagicMock()
78+
mock_workflow_entry.run.return_value = iter([])
79+
80+
with (
81+
patch("core.app.apps.workflow.app_runner.RedisChannel"),
82+
patch("core.app.apps.workflow.app_runner.redis_client"),
83+
patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry) as entry_class,
84+
patch.object(
85+
runner,
86+
"_prepare_single_node_execution",
87+
return_value=(
88+
graph,
89+
variable_pool,
90+
graph_runtime_state,
91+
),
92+
) as prepare_single,
93+
patch.object(runner, "_init_graph") as init_graph,
94+
):
95+
runner.run()
96+
97+
prepare_single.assert_called_once_with(
98+
workflow=workflow,
99+
single_iteration_run=single_iteration_run,
100+
single_loop_run=single_loop_run,
101+
)
102+
init_graph.assert_not_called()
103+
104+
entry_kwargs = entry_class.call_args.kwargs
105+
assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER
106+
assert entry_kwargs["variable_pool"] is variable_pool
107+
assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import pytest
6+
7+
from core.app.entities.app_invoke_entities import InvokeFrom
8+
from core.app.workflow.node_factory import DifyNodeFactory
9+
from core.workflow.entities import GraphInitParams
10+
from core.workflow.graph import Graph
11+
from core.workflow.graph.validation import GraphValidationError
12+
from core.workflow.nodes import NodeType
13+
from core.workflow.runtime import GraphRuntimeState, VariablePool
14+
from core.workflow.system_variable import SystemVariable
15+
from models.enums import UserFrom
16+
17+
18+
def _build_iteration_graph(node_id: str) -> dict[str, Any]:
19+
return {
20+
"nodes": [
21+
{
22+
"id": node_id,
23+
"data": {
24+
"type": "iteration",
25+
"title": "Iteration",
26+
"iterator_selector": ["start", "items"],
27+
"output_selector": [node_id, "output"],
28+
},
29+
}
30+
],
31+
"edges": [],
32+
}
33+
34+
35+
def _build_loop_graph(node_id: str) -> dict[str, Any]:
36+
return {
37+
"nodes": [
38+
{
39+
"id": node_id,
40+
"data": {
41+
"type": "loop",
42+
"title": "Loop",
43+
"loop_count": 1,
44+
"break_conditions": [],
45+
"logical_operator": "and",
46+
"loop_variables": [],
47+
"outputs": {},
48+
},
49+
}
50+
],
51+
"edges": [],
52+
}
53+
54+
55+
def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
56+
graph_init_params = GraphInitParams(
57+
tenant_id="tenant",
58+
app_id="app",
59+
workflow_id="workflow",
60+
graph_config=graph_config,
61+
user_id="user",
62+
user_from=UserFrom.ACCOUNT,
63+
invoke_from=InvokeFrom.DEBUGGER,
64+
call_depth=0,
65+
)
66+
graph_runtime_state = GraphRuntimeState(
67+
variable_pool=VariablePool(
68+
system_variables=SystemVariable.default(),
69+
user_inputs={},
70+
environment_variables=[],
71+
),
72+
start_at=0.0,
73+
)
74+
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
75+
76+
77+
def test_iteration_root_requires_skip_validation():
78+
node_id = "iteration-node"
79+
graph_config = _build_iteration_graph(node_id)
80+
node_factory = _make_factory(graph_config)
81+
82+
with pytest.raises(GraphValidationError):
83+
Graph.init(
84+
graph_config=graph_config,
85+
node_factory=node_factory,
86+
root_node_id=node_id,
87+
)
88+
89+
graph = Graph.init(
90+
graph_config=graph_config,
91+
node_factory=node_factory,
92+
root_node_id=node_id,
93+
skip_validation=True,
94+
)
95+
96+
assert graph.root_node.id == node_id
97+
assert graph.root_node.node_type == NodeType.ITERATION
98+
99+
100+
def test_loop_root_requires_skip_validation():
101+
node_id = "loop-node"
102+
graph_config = _build_loop_graph(node_id)
103+
node_factory = _make_factory(graph_config)
104+
105+
with pytest.raises(GraphValidationError):
106+
Graph.init(
107+
graph_config=graph_config,
108+
node_factory=node_factory,
109+
root_node_id=node_id,
110+
)
111+
112+
graph = Graph.init(
113+
graph_config=graph_config,
114+
node_factory=node_factory,
115+
root_node_id=node_id,
116+
skip_validation=True,
117+
)
118+
119+
assert graph.root_node.id == node_id
120+
assert graph.root_node.node_type == NodeType.LOOP

0 commit comments

Comments
 (0)