Commit 2001aee

Adding streaming to llava (#236)
1 parent: 4133898

File tree

1 file changed: +38 -22 lines


llava/llava-v1.6-34b/model/model.py

Lines changed: 38 additions & 22 deletions
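In short, the commit reads a new stream flag from the request (default True), moves the sampling parameters (top_p, temperature, max_tokens) to the top of predict, collects the generate arguments into a generate_args dict, and runs self.model.generate on a background Thread feeding a TextIteratorStreamer so text chunks can be yielded as they are produced; when stream is False the chunks are concatenated into a single {"result": ...} response. A consumption sketch follows the diff below.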
@@ -1,5 +1,6 @@
 import base64
 from io import BytesIO
+from threading import Thread
 
 import requests
 import torch
@@ -14,6 +15,7 @@
 # original LLaVA repository: https://github.com/haotian-liu/LLaVA/
 from llava.model.builder import load_pretrained_model
 from PIL import Image
+from transformers import TextIteratorStreamer
 
 model_path = "liuhaotian/llava-v1.6-34b"
 DEFAULT_IMAGE_TOKEN = "<image>"
@@ -46,18 +48,18 @@ def load(self):
 
     # inference code from: https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/predict.py#L87
     def predict(self, model_input):
-        query = model_input["query"]
-        image = model_input["image"]
+        query = model_input.get("query")
+        image = model_input.get("image")
+        top_p = model_input.get("top_p", 1.0)
+        temperature = model_input.get("temperature", 0.2)
+        max_tokens = model_input.get("max_tokens", 512)
+        stream = model_input.get("stream", True)
 
         if image[:5] == "https":
             image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
         else:
             image = b64_to_pil(image)
 
-        top_p = model_input.get("top_p", 1.0)
-        temperature = model_input.get("temperature", 0.2)
-        max_tokens = model_input.get("max_tokens", 1000)
-
         # Run model inference here
         conv_mode = "llava_v1"
         conv = conv_templates[conv_mode].copy()
@@ -88,21 +90,35 @@ def predict(self, model_input):
             keywords, self.tokenizer, input_ids
         )
 
-        with torch.inference_mode():
-            output = self.model.generate(
-                inputs=input_ids,
-                images=image_tensor,
-                do_sample=True,
-                temperature=temperature,
-                top_p=top_p,
-                max_new_tokens=max_tokens,
-                use_cache=True,
-                stopping_criteria=[stopping_criteria],
+        generate_args = {
+            "inputs": input_ids,
+            "images": image_tensor,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "max_new_tokens": max_tokens,
+            "use_cache": True,
+            "stopping_criteria": [stopping_criteria],
+        }
+
+        def generator():
+            streamer = TextIteratorStreamer(
+                self.tokenizer, skip_prompt=True, timeout=20.0
             )
+            thread = Thread(
+                target=self.model.generate,
+                kwargs={**generate_args, "streamer": streamer},
+            )
+            thread.start()
+            for text in streamer:
+                yield text
+            thread.join()
 
-        output = self.tokenizer.decode(
-            output[0][len(input_ids[0]) :], skip_special_tokens=True
-        )
-        print(output)
-
-        return {"result": output}
+        with torch.inference_mode():
+            if stream:
+                return generator()
+            else:
+                full_text = ""
+                for text in generator():
+                    full_text += text
+                return {"result": full_text}
