Skip to content

Commit 150f86d

Browse files
Extend llm_context_node definition (#5333)
* Updated node definition Signed-off-by: Elena Khaustova <ymax70rus@gmail.com> * Added unit tests Signed-off-by: Elena Khaustova <ymax70rus@gmail.com> --------- Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>
1 parent 1a4e1eb commit 150f86d

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

kedro/pipeline/llm_context.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
time and automatically assigned readable names based on the returned objects.
1515
"""
1616

17-
from collections.abc import Callable, Sequence
17+
from collections.abc import Callable, Iterable, Sequence
1818
from dataclasses import dataclass, field
1919
from typing import Any, NamedTuple, TypeVar
2020

@@ -143,14 +143,17 @@ class LLMContextNode(Node):
143143
```
144144
"""
145145

146-
def __init__(
146+
def __init__( # noqa: PLR0913
147147
self,
148148
*,
149149
outputs: str,
150150
llm: str,
151151
prompts: list[str],
152152
tools: list[_ToolConfig] | None = None,
153153
name: str | None = None,
154+
tags: str | Iterable[str] | None = None,
155+
confirms: str | list[str] | None = None,
156+
namespace: str | None = None,
154157
):
155158
"""Create an LLMContextNode.
156159
@@ -160,6 +163,10 @@ def __init__(
160163
prompts: List of dataset names containing prompt content.
161164
tools: Optional list of tool configurations created via `tool(...)`.
162165
name: Optional node name; also used as the logical context identifier.
166+
tags: Optional set of tags to be applied to the node.
167+
confirms: Optional name or the list of the names of the datasets
168+
that should be confirmed.
169+
namespace: Optional node namespace.
163170
"""
164171
inputs = {"llm": llm}
165172

@@ -201,18 +208,27 @@ def construct_context(llm: object, **kwargs: dict[str, Any]) -> LLMContext:
201208

202209
# call the Node constructor with the func, inputs, outputs, name
203210
super().__init__(
204-
func=construct_context, inputs=inputs, outputs=outputs, name=name
211+
func=construct_context,
212+
inputs=inputs,
213+
outputs=outputs,
214+
name=name,
215+
tags=tags,
216+
confirms=confirms,
217+
namespace=namespace,
205218
)
206219

207220

208221
@experimental
209-
def llm_context_node(
222+
def llm_context_node( # noqa: PLR0913
210223
*,
211224
outputs: str,
212225
llm: str,
213226
prompts: list[str],
214227
tools: list[_ToolConfig] | None = None,
215228
name: str | None = None,
229+
tags: str | Iterable[str] | None = None,
230+
confirms: str | list[str] | None = None,
231+
namespace: str | None = None,
216232
) -> Node:
217233
"""
218234
!!! warning "Experimental"
@@ -229,7 +245,12 @@ def llm_context_node(
229245
prompts: List of dataset names containing prompt content.
230246
tools: Optional list of tool configurations created via `tool(...)`.
231247
Each tool declares the Kedro inputs required to construct it.
232-
name: Optional name for the node and for the created context.
248+
name: Optional node name; also used as the logical context identifier.
249+
tags: Optional set of tags to be applied to the node.
250+
confirms: Optional name or the list of the names of the datasets
251+
that should be confirmed.
252+
namespace: Optional node namespace.
253+
233254
234255
Returns:
235256
A Kedro Node that loads all declared datasets, instantiates tools,
@@ -256,4 +277,7 @@ def llm_context_node(
256277
prompts=prompts,
257278
tools=tools,
258279
name=name,
280+
tags=tags,
281+
confirms=confirms,
282+
namespace=namespace,
259283
)

tests/pipeline/test_llm_context.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from kedro.pipeline.llm_context import (
44
LLMContext,
5+
LLMContextNode,
56
_get_tool_name,
67
_normalize_outputs,
78
llm_context_node,
@@ -185,3 +186,39 @@ class DummyClass:
185186
def test_get_tool_name_all_branches(obj, expected_name):
186187
"""_get_tool_name should derive a stable, human-friendly tool name."""
187188
assert _get_tool_name(obj) == expected_name
189+
190+
191+
def test_llm_context_node_sets_tags():
192+
"""LLMContextNode should propagate tags to the underlying Node."""
193+
node_obj = LLMContextNode(
194+
outputs="out",
195+
llm="llm",
196+
prompts=[],
197+
tags=["llm", "experimental"],
198+
)
199+
200+
assert node_obj.tags == {"llm", "experimental"}
201+
202+
203+
def test_llm_context_node_sets_confirms():
204+
"""LLMContextNode should propagate confirms to the underlying Node."""
205+
node_obj = LLMContextNode(
206+
outputs="out",
207+
llm="llm",
208+
prompts=[],
209+
confirms=["llm", "prompt"],
210+
)
211+
212+
assert node_obj.confirms == ["llm", "prompt"]
213+
214+
215+
def test_llm_context_node_sets_namespace():
216+
"""LLMContextNode should propagate namespace to the underlying Node."""
217+
node_obj = LLMContextNode(
218+
outputs="out",
219+
llm="llm",
220+
prompts=[],
221+
namespace="llm_nodes",
222+
)
223+
224+
assert node_obj.namespace == "llm_nodes"

0 commit comments

Comments
 (0)