-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
+from __future__ import annotations
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 
 if TYPE_CHECKING:
     from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
@@ -73,6 +86,26 @@ class Xinference(LLM): |
             generate_config={"max_tokens": 1024, "stream": True},
         )
 
+    Example:
+
+    .. code-block:: python
+
+        from langchain_community.llms import Xinference
+        from langchain.prompts import PromptTemplate
+
+        llm = Xinference(
+            server_url="http://0.0.0.0:9997",
+            model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
+            stream=True
+        )
+        prompt = PromptTemplate(
+            input_variables=['country'],
+            template="Q: where can we visit in the capital of {country}? A:"
+        )
+        chain = prompt | llm
+        chain.stream(input={'country': 'France'})
+
+
     To view all the supported builtin models, run:
 
     .. code-block:: bash
@@ -216,3 +249,59 @@ def _stream_generate( |
                                 token=token, verbose=self.verbose, log_probs=log_probs
                             )
                         yield token
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        for stream_resp in self._create_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    def _create_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, Any]] = None
+    ) -> Iterator[Any]:
+        if self.client is None:
+            raise ValueError("Client is not initialized!")
+        model = self.client.get_model(self.model_uid)
+        yield from model.generate(prompt=prompt, generate_config=generate_config)
+
+    @staticmethod
+    def _stream_response_to_generation_chunk(
+        stream_response: Any,
+    ) -> GenerationChunk:
+        """Convert a stream response to a generation chunk."""
+        token = ""
+        if isinstance(stream_response, dict):
+            choices = stream_response.get("choices", [])
+            if choices:
+                choice = choices[0]
+                if isinstance(choice, dict):
+                    token = choice.get("text", "")
+
+                    return GenerationChunk(
+                        text=token,
+                        generation_info=dict(
+                            finish_reason=choice.get("finish_reason", None),
+                            logprobs=choice.get("logprobs", None),
+                        ),
+                    )
+                else:
+                    raise TypeError("choices[0] must be a dict")
+            else:
+                return GenerationChunk(text=token)
+        else:
+            raise TypeError("stream_response must be a dict")
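
The new _stream method is what LangChain's stream() API picks up: each raw chunk from the Xinference generate stream is converted into a GenerationChunk by _stream_response_to_generation_chunk, reported to the callback manager, and yielded. Below is a minimal sketch of that conversion on a hand-written chunk; the sample payload is hypothetical (shaped after the fields the helper reads), not captured server output, and it assumes a langchain_community install that includes this patch.

from langchain_community.llms import Xinference

# Hypothetical stream chunk, shaped after the fields the helper reads
# ("choices", "text", "finish_reason", "logprobs"); not real server output.
sample = {"choices": [{"text": "Paris", "finish_reason": None, "logprobs": None}]}

# The staticmethod turns one raw chunk into a GenerationChunk that the
# streaming callbacks and llm.stream() / chain.stream() can consume.
chunk = Xinference._stream_response_to_generation_chunk(sample)
print(chunk.text)             # Paris
print(chunk.generation_info)  # {'finish_reason': None, 'logprobs': None}

End to end, each GenerationChunk.text produced this way is what surfaces as one streamed token when iterating over llm.stream(...) or chain.stream(...), as in the docstring example above.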