Skip to content

Commit 4aba0ab

Browse files
authored
added common prompt load method (#699)
Co-authored-by: scadEfUr
1 parent 36b6b3c commit 4aba0ab

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

langchain/prompts/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Prompt template classes."""
22
from langchain.prompts.base import BasePromptTemplate
33
from langchain.prompts.few_shot import FewShotPromptTemplate
4-
from langchain.prompts.loading import load_from_hub, load_prompt
4+
from langchain.prompts.loading import load_prompt
55
from langchain.prompts.prompt import Prompt, PromptTemplate
66

77
__all__ = [
@@ -10,5 +10,4 @@
1010
"PromptTemplate",
1111
"FewShotPromptTemplate",
1212
"Prompt",
13-
"load_from_hub",
1413
]

langchain/prompts/loading.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from langchain.prompts.few_shot import FewShotPromptTemplate
1313
from langchain.prompts.prompt import PromptTemplate
1414

15+
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
16+
1517

1618
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
1719
"""Get the right type from the config and load it accordingly."""
@@ -93,7 +95,16 @@ def _load_prompt(config: dict) -> PromptTemplate:
9395
return PromptTemplate(**config)
9496

9597

96-
def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
98+
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
99+
"""Unified method for loading a prompt from LangChainHub or local fs."""
100+
if isinstance(path, str) and path.startswith("lc://prompts"):
101+
path = path.lstrip("lc://prompts/")
102+
return _load_from_hub(path)
103+
else:
104+
return _load_prompt_from_file(path)
105+
106+
107+
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
97108
"""Load prompt from file."""
98109
# Convert file to Path object.
99110
if isinstance(file, str):
@@ -125,10 +136,7 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
125136
return load_prompt_from_config(config)
126137

127138

128-
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
129-
130-
131-
def load_from_hub(path: str) -> BasePromptTemplate:
139+
def _load_from_hub(path: str) -> BasePromptTemplate:
132140
"""Load prompt from hub."""
133141
suffix = path.split(".")[-1]
134142
if suffix not in {"py", "json", "yaml"}:
@@ -141,4 +149,4 @@ def load_from_hub(path: str) -> BasePromptTemplate:
141149
file = tmpdirname + "/prompt." + suffix
142150
with open(file, "wb") as f:
143151
f.write(r.content)
144-
return load_prompt(file)
152+
return _load_prompt_from_file(file)

0 commit comments

Comments
 (0)