55from langchain .prompts import PromptTemplate
66from langchain_core .output_parsers import JsonOutputParser
77from langchain_core .runnables import RunnableParallel
8- from langchain_openai import ChatOpenAI
8+ from langchain_openai import ChatOpenAI , AzureChatOpenAI
9+ from langchain_mistralai import ChatMistralAI
10+ from langchain_anthropic import ChatAnthropic
11+ from langchain_groq import ChatGroq
12+ from langchain_fireworks import ChatFireworks
13+ from langchain_google_vertexai import ChatVertexAI
914from langchain_community .chat_models import ChatOllama
1015from tqdm import tqdm
1116from ..utils .logging import get_logger
@@ -88,7 +93,9 @@ def execute(self, state: dict) -> dict:
8893 # Initialize the output parser
8994 if self .node_config .get ("schema" , None ) is not None :
9095 output_parser = JsonOutputParser (pydantic_object = self .node_config ["schema" ])
91- if isinstance (self .llm_model , ChatOpenAI ) and (self .llm_model .model_name == "gpt-4o-mini" or self .llm_model .model_name == "gpt-4o-2024-08-06" ):
96+
97+ # Use built-in structured output for providers that allow it
98+ if isinstance (self .llm_model , (ChatOpenAI , ChatMistralAI , ChatAnthropic , ChatFireworks , ChatGroq , ChatVertexAI )):
9299 self .llm_model = self .llm_model .with_structured_output (
93100 schema = self .node_config ["schema" ],
94101 method = "json_schema" )
@@ -98,7 +105,7 @@ def execute(self, state: dict) -> dict:
98105
99106 format_instructions = output_parser .get_format_instructions ()
100107
101- if isinstance (self .llm_model , ChatOpenAI ) and not self .script_creator or self .force and not self .script_creator or self .is_md_scraper :
108+ if isinstance (self .llm_model , ( ChatOpenAI , AzureChatOpenAI ) ) and not self .script_creator or self .force and not self .script_creator or self .is_md_scraper :
102109 template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
103110 template_chunks_prompt = TEMPLATE_CHUNKS_MD
104111 template_merge_prompt = TEMPLATE_MERGE_MD
0 commit comments