Skip to content

Commit 7129f23

Browse files
authored
output parser serialization (#758)
1 parent f273c50 commit 7129f23

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

langchain/prompts/base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,24 @@ def check_valid_template(
4848
raise ValueError("Invalid prompt schema.")
4949

5050

51-
class BaseOutputParser(ABC):
51+
class BaseOutputParser(BaseModel, ABC):
5252
"""Class to parse the output of an LLM call."""
5353

5454
@abstractmethod
5555
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
5656
"""Parse the output of an LLM call."""
5757

58+
@property
59+
def _type(self) -> str:
60+
"""Return the type key."""
61+
raise NotImplementedError
62+
63+
def dict(self, **kwargs: Any) -> Dict:
64+
"""Return dictionary representation of output parser."""
65+
output_parser_dict = super().dict()
66+
output_parser_dict["_type"] = self._type
67+
return output_parser_dict
68+
5869

5970
class ListOutputParser(BaseOutputParser):
6071
"""Class to parse the output of an LLM call to a list."""
@@ -79,6 +90,11 @@ class RegexParser(BaseOutputParser, BaseModel):
7990
output_keys: List[str]
8091
default_output_key: Optional[str] = None
8192

93+
@property
94+
def _type(self) -> str:
95+
"""Return the type key."""
96+
return "regex_parser"
97+
8298
def parse(self, text: str) -> Dict[str, str]:
8399
"""Parse the output of an LLM call."""
84100
match = re.search(self.regex, text)
@@ -142,7 +158,7 @@ def _prompt_type(self) -> str:
142158

143159
def dict(self, **kwargs: Any) -> Dict:
144160
"""Return dictionary representation of prompt."""
145-
prompt_dict = super().dict()
161+
prompt_dict = super().dict(**kwargs)
146162
prompt_dict["_type"] = self._prompt_type
147163
return prompt_dict
148164

langchain/prompts/loading.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import requests
1010
import yaml
1111

12-
from langchain.prompts.base import BasePromptTemplate
12+
from langchain.prompts.base import BasePromptTemplate, RegexParser
1313
from langchain.prompts.few_shot import FewShotPromptTemplate
1414
from langchain.prompts.prompt import PromptTemplate
1515

@@ -69,6 +69,20 @@ def _load_examples(config: dict) -> dict:
6969
return config
7070

7171

72+
def _load_output_parser(config: dict) -> dict:
73+
"""Load output parser."""
74+
if "output_parser" in config:
75+
if config["output_parser"] is not None:
76+
_config = config["output_parser"]
77+
output_parser_type = _config["_type"]
78+
if output_parser_type == "regex_parser":
79+
output_parser = RegexParser(**_config)
80+
else:
81+
raise ValueError(f"Unsupported output parser {output_parser_type}")
82+
config["output_parser"] = output_parser
83+
return config
84+
85+
7286
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
7387
"""Load the few shot prompt from the config."""
7488
# Load the suffix and prefix templates.
@@ -86,13 +100,15 @@ def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
86100
config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
87101
# Load the examples.
88102
config = _load_examples(config)
103+
config = _load_output_parser(config)
89104
return FewShotPromptTemplate(**config)
90105

91106

92107
def _load_prompt(config: dict) -> PromptTemplate:
93108
"""Load the prompt template from config."""
94109
# Load the template from disk if necessary.
95110
config = _load_template("template", config)
111+
config = _load_output_parser(config)
96112
return PromptTemplate(**config)
97113

98114

0 commit comments

Comments
 (0)