  1 |   1 |  import base64
  2 |   2 |  from io import BytesIO
    |   3 | +from threading import Thread
  3 |   4 |
  4 |   5 |  import requests
  5 |   6 |  import torch

 14 |  15 |  # original LLaVA repository: https://github.com/haotian-liu/LLaVA/
 15 |  16 |  from llava.model.builder import load_pretrained_model
 16 |  17 |  from PIL import Image
    |  18 | +from transformers import TextIteratorStreamer
 17 |  19 |
 18 |  20 |  model_path = "liuhaotian/llava-v1.6-34b"
 19 |  21 |  DEFAULT_IMAGE_TOKEN = "<image>"
@@ -46,18 +48,18 @@ def load(self):
 46 |  48 |
 47 |  49 |      # inference code from: https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/predict.py#L87
 48 |  50 |      def predict(self, model_input):
 49 |     | -        query = model_input["query"]
 50 |     | -        image = model_input["image"]
    |  51 | +        query = model_input.get("query")
    |  52 | +        image = model_input.get("image")
    |  53 | +        top_p = model_input.get("top_p", 1.0)
    |  54 | +        temperature = model_input.get("temperature", 0.2)
    |  55 | +        max_tokens = model_input.get("max_tokens", 512)
    |  56 | +        stream = model_input.get("stream", True)
 51 |  57 |
 52 |  58 |          if image[:5] == "https":
 53 |  59 |              image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
 54 |  60 |          else:
 55 |  61 |              image = b64_to_pil(image)
 56 |  62 |
 57 |     | -        top_p = model_input.get("top_p", 1.0)
 58 |     | -        temperature = model_input.get("temperature", 0.2)
 59 |     | -        max_tokens = model_input.get("max_tokens", 1000)
 60 |     | -
 61 |  63 |          # Run model inference here
 62 |  64 |          conv_mode = "llava_v1"
 63 |  65 |          conv = conv_templates[conv_mode].copy()
@@ -88,21 +90,35 @@ def predict(self, model_input):
 88 |  90 |              keywords, self.tokenizer, input_ids
 89 |  91 |          )
 90 |  92 |
 91 |     | -        with torch.inference_mode():
 92 |     | -            output = self.model.generate(
 93 |     | -                inputs=input_ids,
 94 |     | -                images=image_tensor,
 95 |     | -                do_sample=True,
 96 |     | -                temperature=temperature,
 97 |     | -                top_p=top_p,
 98 |     | -                max_new_tokens=max_tokens,
 99 |     | -                use_cache=True,
100 |     | -                stopping_criteria=[stopping_criteria],
    |  93 | +        generate_args = {
    |  94 | +            "inputs": input_ids,
    |  95 | +            "images": image_tensor,
    |  96 | +            "do_sample": True,
    |  97 | +            "temperature": temperature,
    |  98 | +            "top_p": top_p,
    |  99 | +            "max_new_tokens": max_tokens,
    | 100 | +            "use_cache": True,
    | 101 | +            "stopping_criteria": [stopping_criteria],
    | 102 | +        }
    | 103 | +
    | 104 | +        def generator():
    | 105 | +            streamer = TextIteratorStreamer(
    | 106 | +                self.tokenizer, skip_prompt=True, timeout=20.0
101 | 107 |              )
    | 108 | +            thread = Thread(
    | 109 | +                target=self.model.generate,
    | 110 | +                kwargs={**generate_args, "streamer": streamer},
    | 111 | +            )
    | 112 | +            thread.start()
    | 113 | +            for text in streamer:
    | 114 | +                yield text
    | 115 | +            thread.join()
102 | 116 |
103 |     | -        output = self.tokenizer.decode(
104 |     | -            output[0][len(input_ids[0]) :], skip_special_tokens=True
105 |     | -        )
106 |     | -        print(output)
107 |     | -
108 |     | -        return {"result": output}
    | 117 | +        with torch.inference_mode():
    | 118 | +            if stream:
    | 119 | +                return generator()
    | 120 | +            else:
    | 121 | +                full_text = ""
    | 122 | +                for text in generator():
    | 123 | +                    full_text += text
    | 124 | +                return {"result": full_text}
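After this change, predict() has two return shapes: a generator of decoded text chunks when "stream" is true, and a {"result": ...} dict otherwise. The sketch below shows how a caller might consume both; it is a minimal illustration, assuming `model` is an already-loaded instance of the Truss model class in this file, and the image URL is a placeholder.

# Minimal usage sketch (hypothetical caller); `model` stands in for an
# instance of the model class defined in this file, with load() already run.
streaming_input = {
    "query": "What is shown in this image?",
    "image": "https://example.com/sample.jpg",  # or a base64-encoded image string
    "temperature": 0.2,
    "max_tokens": 512,
    "stream": True,
}

# stream=True: predict() returns a generator that yields text pieces as
# TextIteratorStreamer receives them from the background generate() thread.
for chunk in model.predict(streaming_input):
    print(chunk, end="", flush=True)

# stream=False: predict() drains the same generator internally and returns
# the full completion in a dict.
response = model.predict({**streaming_input, "stream": False})
print(response["result"])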