Skip to content

Commit 3f2ea5c

Browse files
authored
Harrison/load from hub (#580)
1 parent f74ce7a commit 3f2ea5c

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

langchain/prompts/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Prompt template classes."""
22
from langchain.prompts.base import BasePromptTemplate
33
from langchain.prompts.few_shot import FewShotPromptTemplate
4-
from langchain.prompts.loading import load_prompt
4+
from langchain.prompts.loading import load_from_hub, load_prompt
55
from langchain.prompts.prompt import Prompt, PromptTemplate
66

77
__all__ = [
@@ -10,4 +10,5 @@
1010
"PromptTemplate",
1111
"FewShotPromptTemplate",
1212
"Prompt",
13+
"load_from_hub",
1314
]

langchain/prompts/loading.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Load prompts from disk."""
2+
import importlib
23
import json
4+
import tempfile
35
from pathlib import Path
46
from typing import Union
57

8+
import requests
69
import yaml
710

811
from langchain.prompts.base import BasePromptTemplate
@@ -97,7 +100,38 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
97100
elif file_path.suffix == ".yaml":
98101
with open(file_path, "r") as f:
99102
config = yaml.safe_load(f)
103+
elif file_path.suffix == ".py":
104+
spec = importlib.util.spec_from_loader(
105+
"prompt", loader=None, origin=str(file_path)
106+
)
107+
if spec is None:
108+
raise ValueError("could not load spec")
109+
helper = importlib.util.module_from_spec(spec)
110+
with open(file_path, "rb") as f:
111+
exec(f.read(), helper.__dict__)
112+
if not isinstance(helper.PROMPT, BasePromptTemplate):
113+
raise ValueError("Did not get object of type BasePromptTemplate.")
114+
return helper.PROMPT
100115
else:
101-
raise ValueError
116+
raise ValueError(f"Got unsupported file type {file_path.suffix}")
102117
# Load the prompt from the config now.
103118
return load_prompt_from_config(config)
119+
120+
121+
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
122+
123+
124+
def load_from_hub(path: str) -> BasePromptTemplate:
125+
"""Load prompt from hub."""
126+
suffix = path.split(".")[-1]
127+
if suffix not in {"py", "json", "yaml"}:
128+
raise ValueError("Unsupported file type.")
129+
full_url = URL_BASE + path
130+
r = requests.get(full_url)
131+
if r.status_code != 200:
132+
raise ValueError(f"Could not find file at {full_url}")
133+
with tempfile.TemporaryDirectory() as tmpdirname:
134+
file = tmpdirname + "/prompt." + suffix
135+
with open(file, "wb") as f:
136+
f.write(r.content)
137+
return load_prompt(file)

0 commit comments

Comments
 (0)