Skip to content

Commit f273c50

Browse files
authored
add loading chains from hub (#757)
1 parent 1b89a43 commit f273c50

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

langchain/chains/loading.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
"""Functionality for loading chains."""
22
import json
3+
import os
4+
import tempfile
35
from pathlib import Path
46
from typing import Union
57

8+
import requests
69
import yaml
710

811
from langchain.chains.base import Chain
912
from langchain.chains.llm import LLMChain
1013
from langchain.llms.loading import load_llm, load_llm_from_config
1114
from langchain.prompts.loading import load_prompt, load_prompt_from_config
1215

16+
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
17+
1318

1419
def _load_llm_chain(config: dict) -> LLMChain:
1520
"""Load LLM chain from config dict."""
@@ -48,7 +53,16 @@ def load_chain_from_config(config: dict) -> Chain:
4853
return chain_loader(config)
4954

5055

51-
def load_chain(file: Union[str, Path]) -> Chain:
56+
def load_chain(path: Union[str, Path]) -> Chain:
57+
"""Unified method for loading a chain from LangChainHub or local fs."""
58+
if isinstance(path, str) and path.startswith("lc://chains"):
59+
path = os.path.relpath(path, "lc://chains/")
60+
return _load_from_hub(path)
61+
else:
62+
return _load_chain_from_file(path)
63+
64+
65+
def _load_chain_from_file(file: Union[str, Path]) -> Chain:
5266
"""Load chain from file."""
5367
# Convert file to Path object.
5468
if isinstance(file, str):
@@ -66,3 +80,19 @@ def load_chain(file: Union[str, Path]) -> Chain:
6680
raise ValueError("File type must be json or yaml")
6781
# Load the chain from the config now.
6882
return load_chain_from_config(config)
83+
84+
85+
def _load_from_hub(path: str) -> Chain:
86+
"""Load chain from hub."""
87+
suffix = path.split(".")[-1]
88+
if suffix not in {"json", "yaml"}:
89+
raise ValueError("Unsupported file type.")
90+
full_url = URL_BASE + path
91+
r = requests.get(full_url)
92+
if r.status_code != 200:
93+
raise ValueError(f"Could not find file at {full_url}")
94+
with tempfile.TemporaryDirectory() as tmpdirname:
95+
file = tmpdirname + "/chain." + suffix
96+
with open(file, "wb") as f:
97+
f.write(r.content)
98+
return _load_chain_from_file(file)

0 commit comments

Comments
 (0)