diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index bfd37ea58835a..47dbc193b5f6f 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -64,7 +64,7 @@
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
-from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.runnables.config import ensure_config, run_in_executor, get_executor_for_config
 from langchain_core.tracers._streaming import _StreamingCallbackHandler
 from langchain_core.utils.function_calling import (
     convert_to_json_schema,
@@ -904,30 +904,61 @@ def generate(
             run_id=run_id,
             batch_size=len(messages),
         )
-        results = []
         input_messages = [
             _normalize_messages(message_list) for message_list in messages
         ]
-        for i, m in enumerate(input_messages):
+        if len(input_messages) == 1:
             try:
-                results.append(
+                results = [
                     self._generate_with_cache(
-                        m,
+                        input_messages[0],
                         stop=stop,
-                        run_manager=run_managers[i] if run_managers else None,
+                        run_manager=run_managers[0] if run_managers else None,
                         **kwargs,
                     )
-                )
+                ]
             except BaseException as e:
                 if run_managers:
                     generations_with_error_metadata = _generate_response_from_error(e)
-                    run_managers[i].on_llm_error(
+                    run_managers[0].on_llm_error(
                         e,
                         response=LLMResult(
                             generations=[generations_with_error_metadata]
                         ),
                     )
                 raise
+        else:
+            def _invoke(index_and_message: tuple[int, list[BaseMessage]]):
+                i, m = index_and_message
+                try:
+                    return self._generate_with_cache(
+                        m,
+                        stop=stop,
+                        run_manager=run_managers[i] if run_managers else None,
+                        **kwargs,
+                    )
+                except BaseException as e:
+                    return (i, e)
+
+            with get_executor_for_config(None) as executor:
+                mapped = list(
+                    executor.map(_invoke, list(enumerate(input_messages)))
+                )
+            results = []
+            for i, res in enumerate(mapped):
+                if isinstance(res, tuple) and isinstance(res[1], BaseException):
+                    if run_managers:
+                        generations_with_error_metadata = _generate_response_from_error(
+                            res[1]
+                        )
+                        run_managers[i].on_llm_error(
+                            res[1],
+                            response=LLMResult(
+                                generations=[generations_with_error_metadata]
+                            ),
+                        )
+                    raise res[1]
+                results.append(res)
         flattened_outputs = [
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
             for res in results
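
For context, here is a minimal standalone sketch of the fan-out pattern the second hunk introduces. It is not langchain code: it assumes a plain `ThreadPoolExecutor` in place of `get_executor_for_config`, and the names `_call_one`, `run_all`, and the string inputs are illustrative only. The worker catches its own exception and returns it together with the item index, so `executor.map` still yields a result for every input and the caller can re-raise on its own thread.

```python
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor


def _call_one(indexed_item: tuple[int, str]) -> str | tuple[int, BaseException]:
    """Do the work for one input (stand-in for a per-batch model call)."""
    i, item = indexed_item
    try:
        if not item:
            raise ValueError(f"empty input at index {i}")  # simulated failure
        return item.upper()
    except BaseException as e:
        # Return the exception (with its index) instead of raising, so the
        # pool keeps producing results for the remaining inputs.
        return (i, e)


def run_all(items: list[str]) -> list[str]:
    """Fan the worker out over a thread pool and re-raise the first failure."""
    with ThreadPoolExecutor() as executor:
        mapped = list(executor.map(_call_one, enumerate(items)))
    results: list[str] = []
    for res in mapped:
        if isinstance(res, tuple) and isinstance(res[1], BaseException):
            raise res[1]  # surface the per-item error on the caller's thread
        results.append(res)
    return results


if __name__ == "__main__":
    print(run_all(["hello", "world"]))  # -> ['HELLO', 'WORLD']
```

Returning `(index, exception)` from the worker is what lets the diff keep the index available for the matching `run_managers[i].on_llm_error(...)` callback before re-raising, instead of losing it when the pool propagates the exception itself.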