Commit 2a1a134
feat(datasets): Add LangchainPromptDataset to experimental datasets (#1200)
* Add LangchainPromptDataset to experimental datasets
* Add credential handling
* Lint
* Cleanup
* Separate validation from _create_chat_prompt_template
* Change validation function to not try to validate the template format
* Add unit tests for LangChainPromptDataset
* Map constant template type to function
* Better docstrings
* Add LangChainPromptDataset to release notes
* Add new dataset to docs
* Add new dataset to docs index
* Fix mkdocs error
* Add preview method
* Fix preview method, should work on Viz now
* Add requirements to pyproject.toml
* Improve docstrings
* Fix return type on validate function
* Delete coverage.xml
* Remove coverage files that shouldn't be there
* Simplify preview function
* Add better docstring to class
* Lower required langchain version
* Update kedro-datasets/kedro_datasets_experimental/langchain/__init__.py
* Improve docstring
* Add validation for plain string on ChatPromptTemplate
* Fix indentation on docstring
* Update docstring and version
* Remove redundant part of docstring
* Add validation for dataset type
* Update docstring for _build_dataset_config
* Make dataset type parameter mandatory
* Split by period and use only the last two names in dataset type validation
* Separate validation on build config function

Signed-off-by: Laura Couto <[email protected]>
Signed-off-by: L. R. Couto <[email protected]>
Co-authored-by: ElenaKhaustova <[email protected]>
1 parent e3a1b28 commit 2a1a134

File tree

9 files changed: +680 -0 lines changed

kedro-datasets/RELEASE.md

Lines changed: 6 additions & 0 deletions

@@ -3,6 +3,12 @@
 
 - Group datasets documentation according to the dependencies to clean up the nav bar.
 
+- Added the following new **experimental** datasets:
+
+| Type | Description | Location |
+| ----------------------------------- | -------------------------------------------------------- | --------------------------------------- |
+| `langchain.LangChainPromptDataset` | Kedro dataset for loading LangChain prompts | `kedro_datasets_experimental.langchain` |
+
 ## Bug fixes and other changes
 - Add HTMLPreview type.

kedro-datasets/docs/api/kedro_datasets_experimental/index.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ Name | Description
 [langchain.ChatCohereDataset](langchain.ChatCohereDataset.md) | ``ChatCohereDataset`` loads a ChatCohere `langchain` model.
 [langchain.ChatOpenAIDataset](langchain.ChatOpenAIDataset.md) | OpenAI dataset used to access credentials at runtime.
 [langchain.OpenAIEmbeddingsDataset](langchain.OpenAIEmbeddingsDataset.md) | ``OpenAIEmbeddingsDataset`` loads a OpenAIEmbeddings `langchain` model.
+[langchain.LangChainPromptDataset](langchain.LangChainPromptDataset.md) | ``LangChainPromptDataset`` loads a `langchain` prompt template.
 [netcdf.NetCDFDataset](netcdf.NetCDFDataset.md) | ``NetCDFDataset`` loads/saves data from/to a NetCDF file using an underlying filesystem (e.g.: local, S3, GCS). It uses xarray to handle the NetCDF file.
 [polars.PolarsDatabaseDataset](polars.PolarsDatabaseDataset.md) | ``PolarsDatabaseDataset`` implementation to access databases as Polars DataFrames. It supports reading from a SQL query and writing to a database table.
 [prophet.ProphetModelDataset](prophet.ProphetModelDataset.md) | ``ProphetModelDataset`` loads/saves Facebook Prophet models to a JSON file using an underlying filesystem (e.g., local, S3, GCS). It uses Prophet's built-in serialisation to handle the JSON file.
Lines changed: 4 additions & 0 deletions (new file)

@@ -0,0 +1,4 @@
+::: kedro_datasets_experimental.langchain.LangChainPromptDataset
+    options:
+      members: true
+      show_source: true
kedro-datasets/kedro_datasets_experimental/langchain/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -7,19 +7,23 @@
     from ._anthropic import ChatAnthropicDataset
     from ._cohere import ChatCohereDataset
     from ._openai import ChatOpenAIDataset, OpenAIEmbeddingsDataset
+    from .langchain_prompt_dataset import LangChainPromptDataset
+
 except (ImportError, RuntimeError):
     # For documentation builds that might fail due to dependency issues
     # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
     ChatAnthropicDataset: Any
     ChatOpenAIDataset: Any
     OpenAIEmbeddingsDataset: Any
     ChatCohereDataset: Any
+    LangChainPromptDataset: Any
 
 __getattr__, __dir__, __all__ = lazy.attach(
     __name__,
     submod_attrs={
         "_openai": ["ChatOpenAIDataset", "OpenAIEmbeddingsDataset"],
         "_anthropic": ["ChatAnthropicDataset"],
         "_cohere": ["ChatCohereDataset"],
+        "langchain_prompt_dataset": ["LangChainPromptDataset"],
     },
 )
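The `lazy.attach` call above comes from the third-party `lazy_loader` package. As a rough illustration only, here is a stdlib sketch of the deferred-import pattern it implements (PEP 562 module-level `__getattr__`); the `attach` helper below is hypothetical, not the real `lazy_loader` API:

```python
import importlib


def attach(package_name, submod_attrs):
    """Return (__getattr__, __dir__, __all__) that import attributes lazily:
    nothing is imported until an attribute is first accessed."""
    attr_to_submod = {
        attr: submod for submod, attrs in submod_attrs.items() for attr in attrs
    }
    __all__ = sorted(attr_to_submod)

    def __getattr__(name):
        if name in attr_to_submod:
            # Import the owning submodule on demand, then pull the attribute.
            submodule = importlib.import_module(f"{package_name}.{attr_to_submod[name]}")
            return getattr(submodule, name)
        raise AttributeError(f"module {package_name!r} has no attribute {name!r}")

    def __dir__():
        return __all__

    return __getattr__, __dir__, __all__
```

Assigning the three returned objects at module level, as the `__init__.py` above does, makes `from package import LangChainPromptDataset` work without importing langchain until the dataset is actually touched.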
kedro-datasets/kedro_datasets_experimental/langchain/langchain_prompt_dataset.py

Lines changed: 329 additions & 0 deletions (new file)

@@ -0,0 +1,329 @@
import json
from copy import deepcopy
from pathlib import Path
from typing import Any, Union

from kedro.io import AbstractDataset, DatasetError
from kedro.io.catalog_config_resolver import CREDENTIALS_KEY
from kedro.io.core import get_filepath_str, parse_dataset_definition
from langchain.prompts import ChatPromptTemplate, PromptTemplate

from kedro_datasets._typing import JSONPreview


class LangChainPromptDataset(AbstractDataset[Union[PromptTemplate, ChatPromptTemplate], Any]):  # noqa: UP007
    """
    A Kedro dataset for loading LangChain prompt templates from text, JSON, or YAML files.

    This dataset wraps existing Kedro datasets (such as TextDataset, JSONDataset, or YAMLDataset)
    to load prompt configurations and convert them into LangChain `PromptTemplate` or
    `ChatPromptTemplate` objects.

    ### Example usage for the [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/):
    ```yaml
    my_prompt:
      type: kedro_datasets_experimental.langchain.LangChainPromptDataset
      filepath: data/prompts/my_prompt.json
      template: PromptTemplate
      dataset:
        type: json.JSONDataset
        fs_args:
        load_args:
          encoding: utf-8
        save_args:
          ensure_ascii: false
      credentials: dev_creds
      metadata:
        kedro-viz:
          layer: raw
    ```

    ### Example usage for the [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/):
    ```python
    from kedro_datasets_experimental.langchain import LangChainPromptDataset

    dataset = LangChainPromptDataset(
        filepath="data/prompts/my_prompt.json",
        template="PromptTemplate",
        dataset={"type": "json.JSONDataset"}
    )
    prompt = dataset.load()
    print(prompt.format(name="Kedro"))
    ```
    """

    TEMPLATES = {
        "PromptTemplate": "_create_prompt_template",
        "ChatPromptTemplate": "_create_chat_prompt_template",
    }

    VALID_DATASETS = {"text.TextDataset", "json.JSONDataset", "yaml.YAMLDataset"}

    def __init__(  # noqa: PLR0913
        self,
        filepath: str,
        template: str = "PromptTemplate",
        dataset: dict[str, Any] | str | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        """
        Initialize the LangChain prompt dataset.

        Args:
            filepath: Path to the prompt file.
            template: Name of the LangChain template class ("PromptTemplate" or "ChatPromptTemplate").
            dataset: Configuration for the underlying Kedro dataset.
            credentials: Credentials passed to the underlying dataset unless already defined.
            fs_args: Extra arguments passed to the filesystem, if supported.
            metadata: Arbitrary metadata.
            **kwargs: Additional arguments (ignored).
        """
        super().__init__()

        self.metadata = metadata
        self._filepath = get_filepath_str(Path(filepath), kwargs.get("protocol"))

        try:
            self._template_name = template
            self._create_template_function = getattr(self, self.TEMPLATES[template])
        except KeyError:
            raise DatasetError(
                f"Invalid template '{template}'. Must be one of: {list(self.TEMPLATES)}"
            )

        # Infer dataset type if not explicitly provided
        dataset_config = self._build_dataset_config(dataset)

        # Handle credentials
        self._credentials = deepcopy(credentials or {})
        self._fs_args = deepcopy(fs_args or {})

        if self._credentials:
            if CREDENTIALS_KEY in dataset_config:
                self._logger.warning(
                    "Top-level credentials will not propagate into the underlying dataset "
                    "since credentials were explicitly defined in the dataset config."
                )
            else:
                dataset_config[CREDENTIALS_KEY] = deepcopy(self._credentials)

        if self._fs_args:
            if "fs_args" in dataset_config:
                self._logger.warning(
                    "Top-level fs_args will not propagate into the underlying dataset "
                    "since fs_args were explicitly defined in the dataset config."
                )
            else:
                dataset_config["fs_args"] = deepcopy(self._fs_args)

        try:
            dataset_class, dataset_kwargs = parse_dataset_definition(dataset_config)
            self._dataset = dataset_class(**dataset_kwargs)
        except Exception as e:
            raise DatasetError(f"Failed to create underlying dataset: {e}") from e

    def _validate_dataset_type(self, dataset: dict[str, Any] | str | None) -> None:
        """Validate that the dataset type is supported and not None."""
        if dataset is None:
            raise DatasetError(f"Underlying dataset type cannot be empty: {self._filepath}")

        dataset_type = dataset["type"] if isinstance(dataset, dict) else str(dataset)
        normalized_type = ".".join(dataset_type.split(".")[-2:])
        if normalized_type not in self.VALID_DATASETS:
            raise DatasetError(
                f"Unsupported dataset type '{dataset_type}'. "
                f"Allowed dataset types are: {', '.join(self.VALID_DATASETS)}"
            )

    def _build_dataset_config(self, dataset: dict[str, Any] | str | None) -> dict[str, Any]:
        """
        Build dataset configuration.

        Raises:
            DatasetError: If the dataset type is unsupported. Currently supported dataset
                types are: text.TextDataset, json.JSONDataset, yaml.YAMLDataset.

        Returns:
            dict: A normalized dataset configuration dictionary.
        """
        self._validate_dataset_type(dataset)
        dataset_config = dataset if isinstance(dataset, dict) else {"type": dataset}
        dataset_config = deepcopy(dataset_config)
        dataset_config["filepath"] = self._filepath
        return dataset_config

    def load(self) -> PromptTemplate | ChatPromptTemplate:
        """
        Loads the underlying dataset and converts the data into a LangChain prompt template.

        This method retrieves raw prompt data from the underlying dataset (e.g., a JSON or YAML file)
        and constructs the corresponding LangChain template, either a `PromptTemplate` or a
        `ChatPromptTemplate`, depending on the dataset configuration.

        Raises:
            DatasetError: If the dataset cannot be loaded, contains no data, or cannot be
                converted into the expected prompt template.

        Returns:
            PromptTemplate | ChatPromptTemplate:
                A fully initialized LangChain prompt object created from the dataset contents.

        Example:
            >>> dataset.load()
            ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant."),
                ("human", "{input}")
            ])
        """
        try:
            raw_data = self._dataset.load()
        except Exception as e:
            raise DatasetError(f"Failed to load data from {self._filepath}: {e}") from e

        if raw_data is None:
            raise DatasetError(f"No data loaded from {self._filepath}")

        try:
            return self._create_template_function(raw_data)
        except Exception as e:
            raise DatasetError(f"Failed to create {self._template_name}: {e}") from e

    def _create_prompt_template(self, raw_data: str | dict[str, Any]) -> PromptTemplate:
        """
        Create a `PromptTemplate` from loaded raw data.

        This method supports either a string template or a dictionary
        containing the prompt configuration.

        Args:
            raw_data (str | dict): Either a string representing the template,
                or a dictionary with keys compatible with `PromptTemplate` initialization
                (e.g., `template`, `input_variables`).

        Returns:
            PromptTemplate: A LangChain `PromptTemplate` instance initialized
                with the provided template data.

        Raises:
            DatasetError: If `raw_data` is not a string or dictionary.

        Examples:
            >>> dataset._create_prompt_template("Hello {name}!")
            PromptTemplate(template='Hello {name}!', input_variables=['name'])

            >>> dataset._create_prompt_template({
            ...     "template": "Hello {name}!",
            ...     "input_variables": ["name"]
            ... })
            PromptTemplate(template='Hello {name}!', input_variables=['name'])
        """
        if isinstance(raw_data, str):
            return PromptTemplate.from_template(raw_data)

        if isinstance(raw_data, dict):
            return PromptTemplate(**raw_data)

        raise DatasetError(f"Unsupported data type for PromptTemplate: {type(raw_data)}")

    def _validate_chat_prompt_data(self, data: dict | list[tuple[str, str]]) -> dict | list[tuple[str, str]]:
        """
        Validate that chat prompt data exists and is not empty.

        Raises an error if data is a plain string, which is only compatible with PromptTemplate.

        Returns validated and unpacked messages as a dictionary or a list of tuples.

        Raises:
            DatasetError: If the data is empty or is a plain string.
        """
        if isinstance(data, str):
            raise DatasetError(
                "Plain string data is only supported for PromptTemplate, not ChatPromptTemplate."
            )

        messages = data.get("messages") if isinstance(data, dict) else data
        if not messages:
            raise DatasetError(
                "ChatPromptTemplate requires a non-empty list of messages"
            )

        return messages

    def _create_chat_prompt_template(self, data: dict | list[tuple[str, str]]) -> ChatPromptTemplate:
        """
        Create a `ChatPromptTemplate` from validated chat data.

        Supports either:
          - A dictionary in the LangChain chat JSON format
            (`{"messages": [{"role": "...", "content": "..."}]}`),
          - Or a list of `(role, content)` tuples.

        Args:
            data (dict | list[tuple[str, str]]): Chat prompt data to validate and transform.

        Returns:
            ChatPromptTemplate: A LangChain `ChatPromptTemplate` instance.

        Raises:
            DatasetError: If the data cannot be used to create a `ChatPromptTemplate`.

        Examples:
            >>> dataset._create_chat_prompt_template({
            ...     "messages": [
            ...         {"role": "system", "content": "You are a helpful assistant."},
            ...         {"role": "user", "content": "Hello, who are you?"}
            ...     ]
            ... })
            ChatPromptTemplate(messages=[...])

            >>> dataset._create_chat_prompt_template([
            ...     ("user", "Hello"),
            ...     ("ai", "Hi there!")
            ... ])
            ChatPromptTemplate(messages=[...])
        """
        messages = self._validate_chat_prompt_data(data)
        return ChatPromptTemplate.from_messages(messages)

    def save(self, data: Any) -> None:
        raise DatasetError("Saving is not supported for LangChainPromptDataset")

    def _describe(self) -> dict[str, Any]:
        clean_config = {
            k: v for k, v in getattr(self._dataset, "_config", {}).items() if k != CREDENTIALS_KEY
        }
        return {
            "path": self._filepath,
            "template": self._template_name,
            "underlying_dataset": self._dataset.__class__.__name__,
            "dataset_config": clean_config,
        }

    def _exists(self) -> bool:
        return self._dataset._exists() if hasattr(self._dataset, "_exists") else True

    def preview(self) -> JSONPreview:
        """
        Generate a JSON-compatible preview of the underlying prompt data for Kedro-Viz.

        Returns:
            JSONPreview:
                A Kedro-Viz-compatible object containing a serialized JSON string of the
                processed data. If an exception occurs during processing, the returned
                JSONPreview contains an error message instead of the dataset content.

        Example:
            >>> dataset.preview()
            JSONPreview('{"messages": [{"role": "system", "content": "You are..."}]}')
        """
        try:
            data = self._dataset.load()

            if isinstance(data, str):
                # Wrap plain text in a dictionary or Viz doesn't render it
                data = {"text": data}

            return JSONPreview(json.dumps(data))

        except Exception as e:
            return JSONPreview(f"Error generating preview: {e}")
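The `_validate_dataset_type` normalization above (split on `.` and keep only the last two name parts, so both short and fully qualified type paths match the allow-list) can be shown in isolation. A standalone sketch with hypothetical helper names, not part of the dataset class:

```python
# Allow-list of underlying dataset types, as in LangChainPromptDataset.VALID_DATASETS.
VALID_DATASETS = {"text.TextDataset", "json.JSONDataset", "yaml.YAMLDataset"}


def normalize_dataset_type(dataset_type: str) -> str:
    """Keep only the last two dot-separated parts of a type path,
    e.g. 'kedro_datasets.json.JSONDataset' -> 'json.JSONDataset'."""
    return ".".join(dataset_type.split(".")[-2:])


def is_supported(dataset_type: str) -> bool:
    # Both 'json.JSONDataset' and 'kedro_datasets.json.JSONDataset' pass;
    # a bare class name like 'JSONDataset' does not, since it lacks the module part.
    return normalize_dataset_type(dataset_type) in VALID_DATASETS
```

This is why the catalog config may reference the dataset either by its short registered name or by its full import path.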

kedro-datasets/kedro_datasets_experimental/tests/langchain/__init__.py

Whitespace-only changes.
