|
5 | 5 | from transformers import AutoModel, AutoTokenizer |
6 | 6 | from typing import Union |
7 | 7 | from base_vlm import BaseVLM |
| 8 | +from utils import GenerationConfig |
8 | 9 |
|
9 | 10 | IMAGENET_MEAN = (0.485, 0.456, 0.406) |
10 | 11 | IMAGENET_STD = (0.229, 0.224, 0.225) |
@@ -135,7 +136,9 @@ def __init__(self) -> None: |
135 | 136 | self.model_id, trust_remote_code=True, use_fast=False |
136 | 137 | ) |
137 | 138 |
|
138 | | - def generate(self, image, text: str, max_new_tokens: int = 256): |
| 139 | + def generate( |
| 140 | + self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig() |
| 141 | + ): |
139 | 142 | text = text.replace("<image>", "") |
140 | 143 | if "<image>" not in text: |
141 | 144 | if isinstance(image, list): |
@@ -164,14 +167,12 @@ def generate(self, image, text: str, max_new_tokens: int = 256): |
164 | 167 | load_image(image, max_num=12).to(self.model.device).to(self.model.dtype) |
165 | 168 | ) |
166 | 169 |
|
167 | | - generation_config = dict(max_new_tokens=max_new_tokens, do_sample=False) |
168 | | - |
169 | 170 | response = self.model.chat( |
170 | 171 | self.tokenizer, |
171 | 172 | pixel_values, |
172 | 173 | text, |
173 | | - generation_config, |
174 | 174 | num_patches_list=num_patches_list, |
| 175 | + generation_config=gen_kwargs.__dict__, |
175 | 176 | ) |
176 | 177 | generated_text = response |
177 | 178 | return generated_text |
|
0 commit comments