|
3 | 3 | """ |
4 | 4 | from typing import List, Optional |
5 | 5 | from json.decoder import JSONDecodeError |
| 6 | +import time |
6 | 7 | from langchain.prompts import PromptTemplate |
7 | 8 | from langchain_core.output_parsers import JsonOutputParser |
8 | 9 | from langchain_core.runnables import RunnableParallel |
|
12 | 13 | from tqdm import tqdm |
13 | 14 | from .base_node import BaseNode |
14 | 15 | from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser |
| 16 | +from requests.exceptions import Timeout |
| 17 | +from langchain.callbacks.manager import CallbackManager |
| 18 | +from langchain.callbacks import get_openai_callback |
15 | 19 | from ..prompts import ( |
16 | 20 | TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, |
17 | 21 | TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD |
18 | 22 | ) |
19 | | -from langchain.callbacks.manager import CallbackManager |
20 | | -from langchain.callbacks import get_openai_callback |
21 | | -from requests.exceptions import Timeout |
22 | | -import time |
23 | 23 |
|
24 | 24 | class GenerateAnswerNode(BaseNode): |
25 | 25 | """ |
@@ -82,11 +82,8 @@ def execute(self, state: dict) -> dict: |
82 | 82 |
|
83 | 83 | if self.node_config.get("schema", None) is not None: |
84 | 84 | if isinstance(self.llm_model, ChatOpenAI): |
85 | | - self.llm_model = self.llm_model.with_structured_output( |
86 | | - schema=self.node_config["schema"] |
87 | | - ) |
88 | | - output_parser = get_structured_output_parser(self.node_config["schema"]) |
89 | | - format_instructions = "NA" |
| 85 | + output_parser = get_pydantic_output_parser(self.node_config["schema"]) |
| 86 | + format_instructions = output_parser.get_format_instructions() |
90 | 87 | else: |
91 | 88 | if not isinstance(self.llm_model, ChatBedrock): |
92 | 89 | output_parser = get_pydantic_output_parser(self.node_config["schema"]) |
|
0 commit comments