|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import Any |
| 4 | +from unittest.mock import MagicMock, patch |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from core.app.apps.base_app_queue_manager import AppQueueManager |
| 9 | +from core.app.apps.workflow.app_runner import WorkflowAppRunner |
| 10 | +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity |
| 11 | +from core.workflow.runtime import GraphRuntimeState, VariablePool |
| 12 | +from core.workflow.system_variable import SystemVariable |
| 13 | +from models.workflow import Workflow |
| 14 | + |
| 15 | + |
| 16 | +def _make_graph_state(): |
| 17 | + variable_pool = VariablePool( |
| 18 | + system_variables=SystemVariable.default(), |
| 19 | + user_inputs={}, |
| 20 | + environment_variables=[], |
| 21 | + conversation_variables=[], |
| 22 | + ) |
| 23 | + return MagicMock(), variable_pool, GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) |
| 24 | + |
| 25 | + |
| 26 | +@pytest.mark.parametrize( |
| 27 | + ("single_iteration_run", "single_loop_run"), |
| 28 | + [ |
| 29 | + (WorkflowAppGenerateEntity.SingleIterationRunEntity(node_id="iter", inputs={}), None), |
| 30 | + (None, WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id="loop", inputs={})), |
| 31 | + ], |
| 32 | +) |
| 33 | +def test_run_uses_single_node_execution_branch( |
| 34 | + single_iteration_run: Any, |
| 35 | + single_loop_run: Any, |
| 36 | +) -> None: |
| 37 | + app_config = MagicMock() |
| 38 | + app_config.app_id = "app" |
| 39 | + app_config.tenant_id = "tenant" |
| 40 | + app_config.workflow_id = "workflow" |
| 41 | + |
| 42 | + app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity) |
| 43 | + app_generate_entity.app_config = app_config |
| 44 | + app_generate_entity.inputs = {} |
| 45 | + app_generate_entity.files = [] |
| 46 | + app_generate_entity.user_id = "user" |
| 47 | + app_generate_entity.invoke_from = InvokeFrom.SERVICE_API |
| 48 | + app_generate_entity.workflow_execution_id = "execution-id" |
| 49 | + app_generate_entity.task_id = "task-id" |
| 50 | + app_generate_entity.call_depth = 0 |
| 51 | + app_generate_entity.trace_manager = None |
| 52 | + app_generate_entity.single_iteration_run = single_iteration_run |
| 53 | + app_generate_entity.single_loop_run = single_loop_run |
| 54 | + |
| 55 | + workflow = MagicMock(spec=Workflow) |
| 56 | + workflow.tenant_id = "tenant" |
| 57 | + workflow.app_id = "app" |
| 58 | + workflow.id = "workflow" |
| 59 | + workflow.type = "workflow" |
| 60 | + workflow.version = "v1" |
| 61 | + workflow.graph_dict = {"nodes": [], "edges": []} |
| 62 | + workflow.environment_variables = [] |
| 63 | + |
| 64 | + runner = WorkflowAppRunner( |
| 65 | + application_generate_entity=app_generate_entity, |
| 66 | + queue_manager=MagicMock(spec=AppQueueManager), |
| 67 | + variable_loader=MagicMock(), |
| 68 | + workflow=workflow, |
| 69 | + system_user_id="system-user", |
| 70 | + workflow_execution_repository=MagicMock(), |
| 71 | + workflow_node_execution_repository=MagicMock(), |
| 72 | + ) |
| 73 | + |
| 74 | + graph, variable_pool, graph_runtime_state = _make_graph_state() |
| 75 | + mock_workflow_entry = MagicMock() |
| 76 | + mock_workflow_entry.graph_engine = MagicMock() |
| 77 | + mock_workflow_entry.graph_engine.layer = MagicMock() |
| 78 | + mock_workflow_entry.run.return_value = iter([]) |
| 79 | + |
| 80 | + with ( |
| 81 | + patch("core.app.apps.workflow.app_runner.RedisChannel"), |
| 82 | + patch("core.app.apps.workflow.app_runner.redis_client"), |
| 83 | + patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry) as entry_class, |
| 84 | + patch.object( |
| 85 | + runner, |
| 86 | + "_prepare_single_node_execution", |
| 87 | + return_value=( |
| 88 | + graph, |
| 89 | + variable_pool, |
| 90 | + graph_runtime_state, |
| 91 | + ), |
| 92 | + ) as prepare_single, |
| 93 | + patch.object(runner, "_init_graph") as init_graph, |
| 94 | + ): |
| 95 | + runner.run() |
| 96 | + |
| 97 | + prepare_single.assert_called_once_with( |
| 98 | + workflow=workflow, |
| 99 | + single_iteration_run=single_iteration_run, |
| 100 | + single_loop_run=single_loop_run, |
| 101 | + ) |
| 102 | + init_graph.assert_not_called() |
| 103 | + |
| 104 | + entry_kwargs = entry_class.call_args.kwargs |
| 105 | + assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER |
| 106 | + assert entry_kwargs["variable_pool"] is variable_pool |
| 107 | + assert entry_kwargs["graph_runtime_state"] is graph_runtime_state |
0 commit comments