Skip to content

Commit d162f7e

Browse files
authored
feat(api): automatic NODE_TYPE_CLASSES_MAPPING generation from node class definitions (langgenius#28525)
1 parent 2f8cb2a commit d162f7e

File tree

11 files changed

+245
-189
lines changed

11 files changed

+245
-189
lines changed

api/app_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
5151
ext_commands,
5252
ext_compress,
5353
ext_database,
54+
ext_forward_refs,
5455
ext_hosting_provider,
5556
ext_import_modules,
5657
ext_logging,
@@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
7576
ext_warnings,
7677
ext_import_modules,
7778
ext_orjson,
79+
ext_forward_refs,
7880
ext_set_secretkey,
7981
ext_compress,
8082
ext_code_based_extension,

api/core/app/entities/app_invoke_entities.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel):
130130
# extra parameters, like: auto_generate_conversation_name
131131
extras: dict[str, Any] = Field(default_factory=dict)
132132

133-
# tracing instance
133+
# tracing instance; use forward ref to avoid circular import at import time
134134
trace_manager: Optional["TraceQueueManager"] = None
135135

136136

@@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
275275
start_node_id: str | None = None
276276

277277

278-
# Import TraceQueueManager at runtime to resolve forward references
279-
from core.ops.ops_trace_manager import TraceQueueManager
278+
# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports.
279+
# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to
280+
# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup.
280281

281-
# Rebuild models that use forward references
282-
AppGenerateEntity.model_rebuild()
283-
EasyUIBasedAppGenerateEntity.model_rebuild()
284-
ConversationAppGenerateEntity.model_rebuild()
285-
ChatAppGenerateEntity.model_rebuild()
286-
CompletionAppGenerateEntity.model_rebuild()
287-
AgentChatAppGenerateEntity.model_rebuild()
288-
AdvancedChatAppGenerateEntity.model_rebuild()
289-
WorkflowAppGenerateEntity.model_rebuild()
290-
RagPipelineGenerateEntity.model_rebuild()
282+
283+
# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet.
284+
class _TraceQueueManagerStub:
285+
pass
286+
287+
288+
_ns = {"TraceQueueManager": _TraceQueueManagerStub}
289+
AppGenerateEntity.model_rebuild(_types_namespace=_ns)
290+
EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns)
291+
ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns)
292+
ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
293+
CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns)
294+
AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
295+
AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
296+
WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns)
297+
RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns)

api/core/workflow/nodes/base/node.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import importlib
12
import logging
3+
import operator
4+
import pkgutil
25
from abc import abstractmethod
36
from collections.abc import Generator, Mapping, Sequence
47
from functools import singledispatchmethod
8+
from types import MappingProxyType
59
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
610
from uuid import uuid4
711

@@ -134,6 +138,34 @@ class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
134138

135139
cls._node_data_type = node_data_type
136140

141+
# Skip base class itself
142+
if cls is Node:
143+
return
144+
# Only register production node implementations defined under core.workflow.nodes.*
145+
# This prevents test helper subclasses from polluting the global registry and
146+
# accidentally overriding real node types (e.g., a test Answer node).
147+
module_name = getattr(cls, "__module__", "")
148+
# Only register concrete subclasses that define node_type and version()
149+
node_type = cls.node_type
150+
version = cls.version()
151+
bucket = Node._registry.setdefault(node_type, {})
152+
if module_name.startswith("core.workflow.nodes."):
153+
# Production node definitions take precedence and may override
154+
bucket[version] = cls # type: ignore[index]
155+
else:
156+
# External/test subclasses may register but must not override production
157+
bucket.setdefault(version, cls) # type: ignore[index]
158+
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
159+
version_keys = [v for v in bucket if v != "latest"]
160+
numeric_pairs: list[tuple[str, int]] = []
161+
for v in version_keys:
162+
numeric_pairs.append((v, int(v)))
163+
if numeric_pairs:
164+
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
165+
else:
166+
latest_key = max(version_keys) if version_keys else version
167+
bucket["latest"] = bucket[latest_key]
168+
137169
@classmethod
138170
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
139171
"""
@@ -165,6 +197,9 @@ def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
165197

166198
return None
167199

200+
# Global registry populated via __init_subclass__
201+
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
202+
168203
def __init__(
169204
self,
170205
id: str,
@@ -395,6 +430,29 @@ def version(cls) -> str:
395430
# in `api/core/workflow/nodes/__init__.py`.
396431
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
397432

433+
@classmethod
434+
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
435+
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
436+
437+
Import all modules under core.workflow.nodes so subclasses register themselves on import.
438+
Then we return a readonly view of the registry to avoid accidental mutation.
439+
"""
440+
# Import all node modules to ensure they are loaded (thus registered)
441+
import core.workflow.nodes as _nodes_pkg
442+
443+
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
444+
# Avoid importing modules that depend on the registry to prevent circular imports
445+
# e.g. node_factory imports node_mapping which builds the mapping here.
446+
if _modname in {
447+
"core.workflow.nodes.node_factory",
448+
"core.workflow.nodes.node_mapping",
449+
}:
450+
continue
451+
importlib.import_module(_modname)
452+
453+
# Return a readonly view so callers can't mutate the registry by accident
454+
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
455+
398456
@property
399457
def retry(self) -> bool:
400458
return False
Lines changed: 2 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,165 +1,9 @@
11
from collections.abc import Mapping
22

33
from core.workflow.enums import NodeType
4-
from core.workflow.nodes.agent.agent_node import AgentNode
5-
from core.workflow.nodes.answer.answer_node import AnswerNode
64
from core.workflow.nodes.base.node import Node
7-
from core.workflow.nodes.code import CodeNode
8-
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
9-
from core.workflow.nodes.document_extractor import DocumentExtractorNode
10-
from core.workflow.nodes.end.end_node import EndNode
11-
from core.workflow.nodes.http_request import HttpRequestNode
12-
from core.workflow.nodes.human_input import HumanInputNode
13-
from core.workflow.nodes.if_else import IfElseNode
14-
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
15-
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
16-
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
17-
from core.workflow.nodes.list_operator import ListOperatorNode
18-
from core.workflow.nodes.llm import LLMNode
19-
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
20-
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
21-
from core.workflow.nodes.question_classifier import QuestionClassifierNode
22-
from core.workflow.nodes.start import StartNode
23-
from core.workflow.nodes.template_transform import TemplateTransformNode
24-
from core.workflow.nodes.tool import ToolNode
25-
from core.workflow.nodes.trigger_plugin import TriggerEventNode
26-
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
27-
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
28-
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
29-
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
30-
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
315

326
LATEST_VERSION = "latest"
337

34-
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
35-
# Specifically, if you have introduced new node types, you should add them here.
36-
#
37-
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
38-
# hook. Try to avoid duplication of node information.
39-
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
40-
NodeType.START: {
41-
LATEST_VERSION: StartNode,
42-
"1": StartNode,
43-
},
44-
NodeType.END: {
45-
LATEST_VERSION: EndNode,
46-
"1": EndNode,
47-
},
48-
NodeType.ANSWER: {
49-
LATEST_VERSION: AnswerNode,
50-
"1": AnswerNode,
51-
},
52-
NodeType.LLM: {
53-
LATEST_VERSION: LLMNode,
54-
"1": LLMNode,
55-
},
56-
NodeType.KNOWLEDGE_RETRIEVAL: {
57-
LATEST_VERSION: KnowledgeRetrievalNode,
58-
"1": KnowledgeRetrievalNode,
59-
},
60-
NodeType.IF_ELSE: {
61-
LATEST_VERSION: IfElseNode,
62-
"1": IfElseNode,
63-
},
64-
NodeType.CODE: {
65-
LATEST_VERSION: CodeNode,
66-
"1": CodeNode,
67-
},
68-
NodeType.TEMPLATE_TRANSFORM: {
69-
LATEST_VERSION: TemplateTransformNode,
70-
"1": TemplateTransformNode,
71-
},
72-
NodeType.QUESTION_CLASSIFIER: {
73-
LATEST_VERSION: QuestionClassifierNode,
74-
"1": QuestionClassifierNode,
75-
},
76-
NodeType.HTTP_REQUEST: {
77-
LATEST_VERSION: HttpRequestNode,
78-
"1": HttpRequestNode,
79-
},
80-
NodeType.TOOL: {
81-
LATEST_VERSION: ToolNode,
82-
# This is an issue that caused problems before.
83-
# Logically, we shouldn't use two different versions to point to the same class here,
84-
# but in order to maintain compatibility with historical data, this approach has been retained.
85-
"2": ToolNode,
86-
"1": ToolNode,
87-
},
88-
NodeType.VARIABLE_AGGREGATOR: {
89-
LATEST_VERSION: VariableAggregatorNode,
90-
"1": VariableAggregatorNode,
91-
},
92-
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
93-
LATEST_VERSION: VariableAggregatorNode,
94-
"1": VariableAggregatorNode,
95-
}, # original name of VARIABLE_AGGREGATOR
96-
NodeType.ITERATION: {
97-
LATEST_VERSION: IterationNode,
98-
"1": IterationNode,
99-
},
100-
NodeType.ITERATION_START: {
101-
LATEST_VERSION: IterationStartNode,
102-
"1": IterationStartNode,
103-
},
104-
NodeType.LOOP: {
105-
LATEST_VERSION: LoopNode,
106-
"1": LoopNode,
107-
},
108-
NodeType.LOOP_START: {
109-
LATEST_VERSION: LoopStartNode,
110-
"1": LoopStartNode,
111-
},
112-
NodeType.LOOP_END: {
113-
LATEST_VERSION: LoopEndNode,
114-
"1": LoopEndNode,
115-
},
116-
NodeType.PARAMETER_EXTRACTOR: {
117-
LATEST_VERSION: ParameterExtractorNode,
118-
"1": ParameterExtractorNode,
119-
},
120-
NodeType.VARIABLE_ASSIGNER: {
121-
LATEST_VERSION: VariableAssignerNodeV2,
122-
"1": VariableAssignerNodeV1,
123-
"2": VariableAssignerNodeV2,
124-
},
125-
NodeType.DOCUMENT_EXTRACTOR: {
126-
LATEST_VERSION: DocumentExtractorNode,
127-
"1": DocumentExtractorNode,
128-
},
129-
NodeType.LIST_OPERATOR: {
130-
LATEST_VERSION: ListOperatorNode,
131-
"1": ListOperatorNode,
132-
},
133-
NodeType.AGENT: {
134-
LATEST_VERSION: AgentNode,
135-
# This is an issue that caused problems before.
136-
# Logically, we shouldn't use two different versions to point to the same class here,
137-
# but in order to maintain compatibility with historical data, this approach has been retained.
138-
"2": AgentNode,
139-
"1": AgentNode,
140-
},
141-
NodeType.HUMAN_INPUT: {
142-
LATEST_VERSION: HumanInputNode,
143-
"1": HumanInputNode,
144-
},
145-
NodeType.DATASOURCE: {
146-
LATEST_VERSION: DatasourceNode,
147-
"1": DatasourceNode,
148-
},
149-
NodeType.KNOWLEDGE_INDEX: {
150-
LATEST_VERSION: KnowledgeIndexNode,
151-
"1": KnowledgeIndexNode,
152-
},
153-
NodeType.TRIGGER_WEBHOOK: {
154-
LATEST_VERSION: TriggerWebhookNode,
155-
"1": TriggerWebhookNode,
156-
},
157-
NodeType.TRIGGER_PLUGIN: {
158-
LATEST_VERSION: TriggerEventNode,
159-
"1": TriggerEventNode,
160-
},
161-
NodeType.TRIGGER_SCHEDULE: {
162-
LATEST_VERSION: TriggerScheduleNode,
163-
"1": TriggerScheduleNode,
164-
},
165-
}
8+
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
9+
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

api/core/workflow/nodes/tool/tool_node.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from core.tools.errors import ToolInvokeError
1313
from core.tools.tool_engine import ToolEngine
1414
from core.tools.utils.message_transformer import ToolFileMessageTransformer
15-
from core.tools.workflow_as_tool.tool import WorkflowTool
1615
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
1716
from core.variables.variables import ArrayAnyVariable
1817
from core.workflow.enums import (
@@ -430,7 +429,7 @@ def _transform_message(
430429
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
431430
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
432431
}
433-
if usage.total_tokens > 0:
432+
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
434433
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
435434
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
436435
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
@@ -449,8 +448,17 @@ def _transform_message(
449448

450449
@staticmethod
451450
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
452-
if isinstance(tool_runtime, WorkflowTool):
453-
return tool_runtime.latest_usage
451+
# Avoid importing WorkflowTool at module import time; rely on duck typing
452+
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
453+
latest = getattr(tool_runtime, "latest_usage", None)
454+
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
455+
# for any name, so we must type-check here.
456+
if isinstance(latest, LLMUsage):
457+
return latest
458+
if isinstance(latest, dict):
459+
# Allow dict payloads from external runtimes
460+
return LLMUsage.model_validate(latest)
461+
# Fallback to empty usage when attribute is missing or not a valid payload
454462
return LLMUsage.empty_usage()
455463

456464
@classmethod

0 commit comments

Comments
 (0)