Skip to content

Commit 12f1f3d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b777e59 commit 12f1f3d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

auto_round/compressors/mllm/processor.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
"""
3030
import os
3131
from datetime import datetime, timedelta
32+
from types import SimpleNamespace
3233

3334
import torch
3435
from transformers.data.data_collator import default_data_collator
35-
from types import SimpleNamespace
3636

3737
from .utils import fetch_image
3838

@@ -165,7 +165,7 @@ def get_input(
165165
max_length=None,
166166
truncation=False,
167167
truncation_strategy="text",
168-
**kwargs
168+
**kwargs,
169169
):
170170

171171
if isinstance(text, list):
@@ -196,6 +196,7 @@ def squeeze_result(ret):
196196
ret[key] = ret[key][0]
197197
return ret
198198

199+
199200
@register_processor("longcat_next")
200201
class LongCatNextProcessor(BasicProcessor):
201202
"""Processor for meituan-longcat/LongCat-Next multimodal models.
@@ -210,6 +211,7 @@ class LongCatNextProcessor(BasicProcessor):
210211
IMAGE_TOKEN = "<image>"
211212
LONGCAT_IMG_START = "<longcat_img_start>"
212213
LONGCAT_IMG_END = "<longcat_img_end>"
214+
213215
def __init__(self):
214216
super().__init__()
215217
from transformers.generation.configuration_utils import GenerationConfig
@@ -228,8 +230,8 @@ def __init__(self):
228230
"cfg_scale": 3.0,
229231
"token_h": 37,
230232
"token_w": 37,
231-
"anyres_prefix": "<longcat_img_token_size>{h} {w}</longcat_img_token_size>"
232-
}
233+
"anyres_prefix": "<longcat_img_token_size>{h} {w}</longcat_img_token_size>",
234+
},
233235
}
234236
self.visual_generation_config = GenerationConfig(**visual_config)
235237

@@ -261,9 +263,7 @@ def get_input(self, text, images, squeeze=True, max_length=None, truncation=Fals
261263
)
262264
messages.append(msg)
263265

264-
text_input = self.tokenizer.apply_chat_template(
265-
messages, tokenize=False, add_generation_prompt=True
266-
)
266+
text_input = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
267267
else:
268268
# Plain string input
269269
if max_length is not None:
@@ -546,7 +546,7 @@ def get_input(
546546
max_length=None,
547547
truncation=False,
548548
truncation_strategy="text",
549-
**kwargs
549+
**kwargs,
550550
):
551551
from mistral_common.protocol.instruct.request import ChatCompletionRequest # pylint: disable=E0401
552552

0 commit comments

Comments
 (0)