diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py
index d65c1add..01d834b8 100644
--- a/scrapegraphai/nodes/generate_answer_node.py
+++ b/scrapegraphai/nodes/generate_answer_node.py
@@ -3,6 +3,7 @@
 """
 from typing import List, Optional
 from json.decoder import JSONDecodeError
+import time
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
@@ -12,14 +13,13 @@
 from tqdm import tqdm
 from .base_node import BaseNode
 from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
+from requests.exceptions import Timeout
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks import get_openai_callback
 from ..prompts import (
     TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
     TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
 )
-from langchain.callbacks.manager import CallbackManager
-from langchain.callbacks import get_openai_callback
-from requests.exceptions import Timeout
-import time

 class GenerateAnswerNode(BaseNode):
     """
@@ -82,11 +82,8 @@ def execute(self, state: dict) -> dict:

         if self.node_config.get("schema", None) is not None:
             if isinstance(self.llm_model, ChatOpenAI):
-                self.llm_model = self.llm_model.with_structured_output(
-                    schema=self.node_config["schema"]
-                )
-                output_parser = get_structured_output_parser(self.node_config["schema"])
-                format_instructions = "NA"
+                output_parser = get_pydantic_output_parser(self.node_config["schema"])
+                format_instructions = output_parser.get_format_instructions()
             else:
                 if not isinstance(self.llm_model, ChatBedrock):
                     output_parser = get_pydantic_output_parser(self.node_config["schema"])