1
1
"""Functionality for loading chains."""
2
2
import json
3
+ import os
4
+ import tempfile
3
5
from pathlib import Path
4
6
from typing import Union
5
7
8
+ import requests
6
9
import yaml
7
10
8
11
from langchain .chains .base import Chain
9
12
from langchain .chains .llm import LLMChain
10
13
from langchain .llms .loading import load_llm , load_llm_from_config
11
14
from langchain .prompts .loading import load_prompt , load_prompt_from_config
12
15
16
+ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
17
+
13
18
14
19
def _load_llm_chain (config : dict ) -> LLMChain :
15
20
"""Load LLM chain from config dict."""
@@ -48,7 +53,16 @@ def load_chain_from_config(config: dict) -> Chain:
48
53
return chain_loader (config )
49
54
50
55
51
- def load_chain (file : Union [str , Path ]) -> Chain :
56
+ def load_chain (path : Union [str , Path ]) -> Chain :
57
+ """Unified method for loading a chain from LangChainHub or local fs."""
58
+ if isinstance (path , str ) and path .startswith ("lc://chains" ):
59
+ path = os .path .relpath (path , "lc://chains/" )
60
+ return _load_from_hub (path )
61
+ else :
62
+ return _load_chain_from_file (path )
63
+
64
+
65
+ def _load_chain_from_file (file : Union [str , Path ]) -> Chain :
52
66
"""Load chain from file."""
53
67
# Convert file to Path object.
54
68
if isinstance (file , str ):
@@ -66,3 +80,19 @@ def load_chain(file: Union[str, Path]) -> Chain:
66
80
raise ValueError ("File type must be json or yaml" )
67
81
# Load the chain from the config now.
68
82
return load_chain_from_config (config )
83
+
84
+
85
+ def _load_from_hub (path : str ) -> Chain :
86
+ """Load chain from hub."""
87
+ suffix = path .split ("." )[- 1 ]
88
+ if suffix not in {"json" , "yaml" }:
89
+ raise ValueError ("Unsupported file type." )
90
+ full_url = URL_BASE + path
91
+ r = requests .get (full_url )
92
+ if r .status_code != 200 :
93
+ raise ValueError (f"Could not find file at { full_url } " )
94
+ with tempfile .TemporaryDirectory () as tmpdirname :
95
+ file = tmpdirname + "/chain." + suffix
96
+ with open (file , "wb" ) as f :
97
+ f .write (r .content )
98
+ return _load_chain_from_file (file )
0 commit comments