"""
MIT License

Copyright (c) 2025 LangChain

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field


class LangChainChatModel(BaseChatModel):
    """A custom chat model that echoes the first `parrot_buffer_length` characters
    of the input.

    When contributing an implementation to LangChain, document the model
    thoroughly: describe the initialization parameters, include an example of
    how to initialize the model, and link to any relevant documentation or API
    references for the underlying model.

    Example:

        .. code-block:: python

            model = LangChainChatModel(parrot_buffer_length=2, model="bird-brain-001")
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    model_name: str = Field(alias="model")
    parrot_buffer_length: int
    """The number of characters from the last message of the prompt to echo."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[List[str]] = None
    max_retries: int = 2

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override the _generate method to implement the chat model logic.

        This can be a call to an API, a call to a local model, or any other
        implementation that generates a response to the input prompt.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Echo the first `parrot_buffer_length` characters of the last message.
        last_message = messages[-1]
        tokens = last_message.content[: self.parrot_buffer_length]
        ct_input_tokens = sum(len(message.content) for message in messages)
        ct_output_tokens = len(tokens)
        message = AIMessage(
            content=tokens,
            additional_kwargs={},  # Used to add additional payload to the message
            response_metadata={  # Use for response metadata
                "time_in_seconds": 3,
                "model_name": self.model_name,
            },
            usage_metadata={
                "input_tokens": ct_input_tokens,
                "output_tokens": ct_output_tokens,
                "total_tokens": ct_input_tokens + ct_output_tokens,
            },
        )

        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
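
    # A minimal sketch (an addition, not part of the original example) of how
    # the stop-token guidance in the docstrings here could be honored: truncate
    # at the first matching stop string, keeping the stop token itself in the
    # output. _generate could call `tokens = self._apply_stop(tokens, stop)`
    # before building the AIMessage.
    @staticmethod
    def _apply_stop(text: str, stop: Optional[List[str]]) -> str:
        for token in stop or []:
            idx = text.find(token)
            if idx != -1:
                # Include the stop token in the returned text, per the docstring.
                return text[: idx + len(token)]
        return text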

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This method should be implemented if the model can generate output
        in a streaming fashion. If the model does not support streaming,
        do not implement it. In that case streaming requests will be automatically
        handled by the _generate method.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Stream the echo one character at a time. (Raising NotImplementedError
        # here would break streaming instead of falling back to _generate.)
        last_message = messages[-1]
        tokens = str(last_message.content[: self.parrot_buffer_length])
        ct_input_tokens = sum(len(message.content) for message in messages)

        for token in tokens:
            usage_metadata = {
                "input_tokens": ct_input_tokens,
                "output_tokens": 1,
                "total_tokens": ct_input_tokens + 1,
            }
            # Attribute the input tokens to the first chunk only.
            ct_input_tokens = 0
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
            )
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=chunk)
            yield chunk

        # Emit a final chunk carrying response metadata.
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                content="", response_metadata={"model_name": self.model_name}
            )
        )
        if run_manager:
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing and makes it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per-token pricing for their model and monitor
            # costs for the given LLM).
            "model_name": self.model_name,
        }
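

# A minimal usage sketch (an addition, not part of the original file); the
# model name and buffer length below are illustrative values.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    model = LangChainChatModel(parrot_buffer_length=3, model="bird-brain-001")

    # invoke() returns one AIMessage echoing the first three characters.
    print(model.invoke([HumanMessage(content="hello")]).content)  # hel

    # stream() yields one chunk per echoed character, plus a final
    # metadata-only chunk, via _stream.
    for chunk in model.stream([HumanMessage(content="hello")]):
        print(chunk.content, end="|")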