|
7 | 7 |
|
8 | 8 | # Imports from Langchain |
9 | 9 | from langchain.prompts import PromptTemplate |
10 | | -from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser |
| 10 | +from langchain_core.output_parsers import JsonOutputParser |
11 | 11 | from langchain_core.runnables import RunnableParallel |
12 | 12 | from tqdm import tqdm |
13 | 13 |
|
| 14 | + |
14 | 15 | from ..utils.logging import get_logger |
15 | 16 | from ..models import Ollama |
16 | 17 | # Imports from the library |
@@ -81,8 +82,8 @@ def execute(self, state: dict) -> dict: |
81 | 82 | doc = input_data[1] |
82 | 83 |
|
83 | 84 | # Initialize the output parser |
84 | | - if self.node_config.get("schema",None) is not None: |
85 | | - output_parser = PydanticOutputParser(pydantic_object=self.node_config.get("schema", None)) |
| 85 | + if self.node_config.get("schema", None) is not None: |
| 86 | + output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) |
86 | 87 | else: |
87 | 88 | output_parser = JsonOutputParser() |
88 | 89 |
|
@@ -129,9 +130,6 @@ def execute(self, state: dict) -> dict: |
129 | 130 | single_chain = list(chains_dict.values())[0] |
130 | 131 | answer = single_chain.invoke({"question": user_prompt}) |
131 | 132 |
|
132 | | - if type(answer) == PydanticOutputParser: |
133 | | - answer = answer.model_dump() |
134 | | - |
135 | 133 | # Update the state with the generated answer |
136 | 134 | state.update({self.output[0]: answer}) |
137 | 135 | return state |
0 commit comments