Skip to content

Commit e3230ac

Browse files
authored
fix generation config (#77)
1 parent ad1a24d commit e3230ac

23 files changed

+177
-75
lines changed

examples/EvoVLM_JP_v1_7B.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from transformers import AutoModelForVision2Seq, AutoProcessor
33
import torch
44
from base_vlm import BaseVLM
5+
from utils import GenerationConfig
56

67

78
class VLM(BaseVLM):
@@ -15,8 +16,15 @@ def __init__(self) -> None:
1516
self.processor = AutoProcessor.from_pretrained(self.model_id)
1617
self.model.to(self.device)
1718

18-
def generate(self, image, text: str, max_new_tokens: int = 256):
19-
text = f"<image>{text}"
19+
def generate(
20+
self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()
21+
):
22+
text = text.replace("<image>", "")
23+
if isinstance(image, list):
24+
text = "<image>" * len(image) + f"{text}"
25+
else:
26+
text = f"<image>{text}"
27+
2028
messages = [
2129
{
2230
"role": "system",
@@ -29,7 +37,7 @@ def generate(self, image, text: str, max_new_tokens: int = 256):
2937
messages, return_tensors="pt"
3038
)
3139
output_ids = self.model.generate(
32-
**inputs.to(self.device), max_new_token=max_new_tokens
40+
**inputs.to(self.device), **gen_kwargs.__dict__
3341
)
3442
output_ids = output_ids[:, inputs.input_ids.shape[1] :]
3543
generated_text = self.processor.batch_decode(

examples/GPT_4o.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from io import BytesIO
44
import base64
55
from base_vlm import BaseVLM
6+
from utils import GenerationConfig
67

78

89
def encode_image_to_base64(image):
@@ -22,7 +23,9 @@ def __init__(self) -> None:
2223
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
2324
)
2425

25-
def generate(self, image, text: str, max_new_tokens: int = 256):
26+
def generate(
27+
self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()
28+
):
2629
if "<image>" in text:
2730
text = text.replace("<image>", "")
2831
message = []
@@ -70,7 +73,11 @@ def generate(self, image, text: str, max_new_tokens: int = 256):
7073
]
7174
try:
7275
response = self.client.chat.completions.create(
73-
model=self.model_id, messages=message, max_tokens=max_new_tokens
76+
model=self.model_id,
77+
messages=message,
78+
max_tokens=gen_kwargs.max_new_tokens,
79+
temperature=gen_kwargs.temperature,
80+
top_p=gen_kwargs.top_p,
7481
)
7582
return response.choices[0].message.content
7683
except Exception as e:

examples/InternVL2_8B.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from transformers import AutoModel, AutoTokenizer
66
from typing import Union
77
from base_vlm import BaseVLM
8+
from utils import GenerationConfig
89

910
IMAGENET_MEAN = (0.485, 0.456, 0.406)
1011
IMAGENET_STD = (0.229, 0.224, 0.225)
@@ -135,7 +136,9 @@ def __init__(self) -> None:
135136
self.model_id, trust_remote_code=True, use_fast=False
136137
)
137138

138-
def generate(self, image, text: str, max_new_tokens: int = 256):
139+
def generate(
140+
self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()
141+
):
139142
text = text.replace("<image>", "")
140143
if "<image>" not in text:
141144
if isinstance(image, list):
@@ -164,14 +167,12 @@ def generate(self, image, text: str, max_new_tokens: int = 256):
164167
load_image(image, max_num=12).to(self.model.device).to(self.model.dtype)
165168
)
166169

167-
generation_config = dict(max_new_tokens=max_new_tokens, do_sample=False)
168-
169170
response = self.model.chat(
170171
self.tokenizer,
171172
pixel_values,
172173
text,
173-
generation_config,
174174
num_patches_list=num_patches_list,
175+
generation_config=gen_kwargs.__dict__,
175176
)
176177
generated_text = response
177178
return generated_text

examples/Llama_3_2_11B_Vision_Instruct.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from transformers import MllamaForConditionalGeneration, AutoProcessor
44
from typing import Union
55
from base_vlm import BaseVLM
6+
from utils import GenerationConfig
67

78

89
class VLM(BaseVLM):
@@ -20,7 +21,7 @@ def generate(
2021
self,
2122
images: Union[Image.Image, list[Image.Image]],
2223
text: str,
23-
max_new_tokens: int = 256,
24+
gen_kwargs: GenerationConfig = GenerationConfig(),
2425
):
2526
if "<image>" in text:
2627
text = text.replace("<image>", "")
@@ -41,7 +42,7 @@ def generate(
4142
inputs = self.processor(
4243
images, input_text, add_special_tokens=False, return_tensors="pt"
4344
).to(self.model.device)
44-
output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
45+
output_ids = self.model.generate(**inputs, **gen_kwargs.__dict__)
4546
generated_ids = [
4647
output_ids[len(input_ids) :]
4748
for input_ids, output_ids in zip(inputs.input_ids, output_ids)

examples/Llama_3_2_11B_Vision_Instruct_Swallow_8B_Merge.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from transformers import MllamaForConditionalGeneration, AutoProcessor
44
from typing import Union
55
from base_vlm import BaseVLM
6+
from utils import GenerationConfig
67

78

89
class VLM(BaseVLM):
@@ -20,7 +21,7 @@ def generate(
2021
self,
2122
images: Union[Image.Image, list[Image.Image]],
2223
text: str,
23-
max_new_tokens: int = 256,
24+
gen_kwargs: GenerationConfig = GenerationConfig(),
2425
):
2526
if "<image>" in text:
2627
text = text.replace("<image>", "")
@@ -41,7 +42,7 @@ def generate(
4142
inputs = self.processor(
4243
images, input_text, add_special_tokens=False, return_tensors="pt"
4344
).to(self.model.device)
44-
output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
45+
output_ids = self.model.generate(**inputs, **gen_kwargs.__dict__)
4546
generated_ids = [
4647
output_ids[len(input_ids) :]
4748
for input_ids, output_ids in zip(inputs.input_ids, output_ids)

examples/Llama_3_EZO_VLM_1.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88
from mantis.models.mllava.utils import conv_templates
99
from base_vlm import BaseVLM
10+
from utils import GenerationConfig
1011

1112
# 1. Set the system prompt
1213
conv_llama_3_elyza = Conversation(
@@ -33,13 +34,9 @@ def __init__(self) -> None:
3334
)
3435
self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token
3536

36-
def generate(self, image, text: str, max_new_tokens: int = 256):
37-
generation_kwargs = {
38-
"max_new_tokens": max_new_tokens,
39-
"num_beams": 1,
40-
"do_sample": False,
41-
"no_repeat_ngram_size": 3,
42-
}
37+
def generate(
38+
self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()
39+
):
4340
if isinstance(image, list):
4441
if "<image>" not in text:
4542
text = "<image> " * len(image) + "\n" + text
@@ -49,7 +46,7 @@ def generate(self, image, text: str, max_new_tokens: int = 256):
4946
text = "<image>\n" + text
5047
images = [image]
5148
response, history = chat_mllava(
52-
text, images, self.model, self.processor, **generation_kwargs
49+
text, images, self.model, self.processor, **gen_kwargs.__dict__
5350
)
5451
return response
5552

examples/Llama_3_EvoVLM_JP_v2.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88
from mantis.models.mllava.utils import conv_templates
99
from base_vlm import BaseVLM
10+
from utils import GenerationConfig
1011

1112
# 1. Set the system prompt
1213
conv_llama_3_elyza = Conversation(
@@ -33,13 +34,9 @@ def __init__(self) -> None:
3334
)
3435
self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token
3536

36-
def generate(self, image, text: str, max_new_tokens: int = 256):
37-
generation_kwargs = {
38-
"max_new_tokens": max_new_tokens,
39-
"num_beams": 1,
40-
"do_sample": False,
41-
"no_repeat_ngram_size": 3,
42-
}
37+
def generate(
38+
self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()
39+
):
4340
if isinstance(image, list):
4441
if "<image>" not in text:
4542
text = "<image> " * len(image) + "\n" + text
@@ -49,7 +46,7 @@ def generate(self, image, text: str, max_new_tokens: int = 256):
4946
text = "<image>\n" + text
5047
images = [image]
5148
response, history = chat_mllava(
52-
text, images, self.model, self.processor, **generation_kwargs
49+
text, images, self.model, self.processor, **gen_kwargs.__dict__
5350
)
5451
return response
5552

examples/Pangea_7B_hf.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from PIL import Image
55
from typing import Union
66
from base_vlm import BaseVLM
7+
from utils import GenerationConfig
78

89

910
class VLM(BaseVLM):
@@ -21,7 +22,7 @@ def generate(
2122
self,
2223
images: Union[Image.Image, list[Image.Image]],
2324
text: str,
24-
max_new_tokens: int = 256,
25+
gen_kwargs: GenerationConfig = GenerationConfig(),
2526
):
2627
if isinstance(images, list):
2728
prompt_template = (
@@ -39,11 +40,7 @@ def generate(
3940
).to("cuda", torch.float16)
4041
output = self.model.generate(
4142
**model_inputs,
42-
max_new_tokens=max_new_tokens,
43-
min_new_tokens=32,
44-
temperature=1.0,
45-
top_p=0.9,
46-
do_sample=True,
43+
**gen_kwargs.__dict__,
4744
)
4845
output = output[0]
4946
result = self.processor.decode(

examples/Pixtral_12B_2409.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import base64
66
from io import BytesIO
77
from base_vlm import BaseVLM
8+
from utils import GenerationConfig
89

910

1011
def image_to_base64(img):
@@ -42,7 +43,7 @@ def generate(
4243
self,
4344
images: Union[Image.Image, list[Image.Image]],
4445
text: str,
45-
max_new_tokens: int = 256,
46+
gen_kwargs: GenerationConfig = GenerationConfig(),
4647
):
4748
if isinstance(images, list):
4849
content = [image_to_content(image) for image in images]
@@ -57,7 +58,11 @@ def generate(
5758
}
5859
]
5960

60-
sampling_params = SamplingParams(max_tokens=max_new_tokens)
61+
sampling_params = SamplingParams(
62+
max_tokens=gen_kwargs.max_new_tokens,
63+
temperature=gen_kwargs.temperature,
64+
top_p=gen_kwargs.top_p,
65+
)
6166
outputs = self.model.chat(
6267
messages,
6368
sampling_params=sampling_params,

examples/Qwen2_VL_7B_Instruct.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import base64
44
from qwen_vl_utils import process_vision_info
55
from base_vlm import BaseVLM
6+
from utils import GenerationConfig
67

78

89
class VLM(BaseVLM):
@@ -18,7 +19,9 @@ def __init__(self) -> None:
1819
self.model_id, min_pixels=min_pixels, max_pixels=max_pixels
1920
)
2021

21-
def generate(self, image, text: str, max_new_tokens: int = 256):
22+
def generate(
23+
self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()
24+
):
2225
if "<image>" in text:
2326
text = text.replace("<image>", "")
2427
message = []
@@ -75,7 +78,7 @@ def generate(self, image, text: str, max_new_tokens: int = 256):
7578
)
7679

7780
inputs = inputs.to(self.model.device)
78-
output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
81+
output_ids = self.model.generate(**inputs, **gen_kwargs.__dict__)
7982
generated_ids = [
8083
output_ids[len(input_ids) :]
8184
for input_ids, output_ids in zip(inputs.input_ids, output_ids)

0 commit comments

Comments (0)