
Commit 1cd4d8d

[langchain_community.llms.xinference]: Rewrite _stream() method and support stream() method in xinference.py (#29259)
- [ ] **PR title**: [langchain_community.llms.xinference]: Rewrite _stream() method and support stream() method in xinference.py

- [ ] **PR message**: Rewrite the _stream() method so that chain.stream() can be used to return data streams:

      chain = prompt | llm
      chain.stream(input=user_input)

- [ ] **Tests**:

      from langchain_community.llms import Xinference
      from langchain.prompts import PromptTemplate

      llm = Xinference(
          server_url="http://0.0.0.0:9997",  # replace with your xinference server url
          model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
          stream=True,
      )
      prompt = PromptTemplate(
          input=['country'],
          template="Q: where can we visit in the capital of {country}? A:"
      )
      chain = prompt | llm
      chain.stream(input={'country': 'France'})
1 parent d4b9404 commit 1cd4d8d
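
For context, a minimal sketch of how the streamed output could be consumed once this change is in place, assuming a reachable Xinference server; the server URL and model UID below are placeholders, not values taken from this PR.

```python
from langchain_community.llms import Xinference
from langchain_core.prompts import PromptTemplate

# Placeholder connection details: point these at your own Xinference server
# and the model UID returned when you launched the model.
llm = Xinference(
    server_url="http://0.0.0.0:9997",
    model_uid="my-model-uid",
)

prompt = PromptTemplate.from_template(
    "Q: where can we visit in the capital of {country}? A:"
)
chain = prompt | llm

# With the rewritten _stream(), chain.stream() yields text chunks as they
# arrive instead of a single blocking response.
for chunk in chain.stream({"country": "France"}):
    print(chunk, end="", flush=True)
```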

File tree: 1 file changed (+90 −1 lines)


libs/community/langchain_community/llms/xinference.py

Lines changed: 90 additions & 1 deletion
@@ -1,7 +1,20 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
+from __future__ import annotations
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 
 if TYPE_CHECKING:
     from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
@@ -73,6 +86,26 @@ class Xinference(LLM):
             generate_config={"max_tokens": 1024, "stream": True},
         )
 
+    Example:
+
+    .. code-block:: python
+
+        from langchain_community.llms import Xinference
+        from langchain.prompts import PromptTemplate
+
+        llm = Xinference(
+            server_url="http://0.0.0.0:9997",
+            model_uid={model_uid},  # replace model_uid with the model UID return from launching the model
+            stream=True
+        )
+        prompt = PromptTemplate(
+            input=['country'],
+            template="Q: where can we visit in the capital of {country}? A:"
+        )
+        chain = prompt | llm
+        chain.stream(input={'country': 'France'})
+
+
     To view all the supported builtin models, run:
 
     .. code-block:: bash
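
The docstring example above streams through a chain, but `_stream()` also backs the LLM's own `.stream()` method, so a direct call without a prompt template would look roughly like the sketch below; the server URL and model UID are again placeholders.

```python
from langchain_community.llms import Xinference

# Placeholders: point these at a running Xinference server and a launched model.
llm = Xinference(
    server_url="http://0.0.0.0:9997",
    model_uid="my-model-uid",
)

# llm.stream() yields string tokens produced by the new _stream() generator.
for token in llm.stream("Q: where can we visit in the capital of France? A:"):
    print(token, end="", flush=True)
```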
@@ -216,3 +249,59 @@ def _stream_generate(
                 token=token, verbose=self.verbose, log_probs=log_probs
             )
             yield token
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        for stream_resp in self._create_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    def _create_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> Iterator[str]:
+        if self.client is None:
+            raise ValueError("Client is not initialized!")
+        model = self.client.get_model(self.model_uid)
+        yield from model.generate(prompt=prompt, generate_config=generate_config)
+
+    @staticmethod
+    def _stream_response_to_generation_chunk(
+        stream_response: str,
+    ) -> GenerationChunk:
+        """Convert a stream response to a generation chunk."""
+        token = ""
+        if isinstance(stream_response, dict):
+            choices = stream_response.get("choices", [])
+            if choices:
+                choice = choices[0]
+                if isinstance(choice, dict):
+                    token = choice.get("text", "")
+
+                    return GenerationChunk(
+                        text=token,
+                        generation_info=dict(
+                            finish_reason=choice.get("finish_reason", None),
+                            logprobs=choice.get("logprobs", None),
+                        ),
+                    )
+                else:
+                    raise TypeError("choice type error!")
+            else:
+                return GenerationChunk(text=token)
+        else:
+            raise TypeError("stream_response type error!")
