Skip to content

Commit 23559d4

Browse files
authored
fix max_length error print (#2960)
1 parent db01dea commit 23559d4

File tree

5 files changed

+19
-9
lines changed

5 files changed

+19
-9
lines changed

examples/export/quantize/mllm/awq.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,4 +8,4 @@ swift export \
  8 |   8 |   --max_length 2048 \
  9 |   9 |   --quant_method awq \
 10 |  10 |   --quant_bits 4 \
 11 |     | - --output_dir Qwen/Qwen2-VL-2B-Instruct-AWQ
    |  11 | + --output_dir Qwen2-VL-2B-Instruct-AWQ

examples/export/quantize/mllm/gptq.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,4 +15,4 @@ swift export \
 15 |  15 |   --max_length 2048 \
 16 |  16 |   --quant_method gptq \
 17 |  17 |   --quant_bits 4 \
 18 |     | - --output_dir Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
    |  18 | + --output_dir Qwen2-VL-2B-Instruct-GPTQ-Int4

swift/llm/dataset/preprocessor/core.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,9 @@ def _fix_streaming_keys(row):
143 | 143 |   new_k = k[len('__@'):]
144 | 144 |   row[new_k] = row.pop(k)
145 | 145 |
146 |     | - def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
    | 146 | + def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
    | 147 | + ignore_max_length_error: bool) -> Dict[str, Any]:
    | 148 | + from ...template import MaxLengthError
147 | 149 |   batched_row = dict(batched_row)
148 | 150 |   assert len(batched_row) > 0
149 | 151 |   self._fix_streaming_keys(batched_row)
@@ -162,13 +164,15 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Di
162 | 164 |   self._check_messages(r)
163 | 165 |   self._check_rejected_response(r)
164 | 166 |   self._cast_images(r)
165 |     | - except Exception:
    | 167 | + except Exception as e:
166 | 168 |   if strict:
167 | 169 |   logger.warning('To avoid errors, you can pass `strict=False`.')
168 | 170 |   raise
169 |     | - if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
    | 171 | + if isinstance(e, MaxLengthError) and ignore_max_length_error:
    | 172 | + pass
    | 173 | + elif self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
170 | 174 |   import traceback
171 |     | - print(traceback.format_exc())
    | 175 | + logger.info(traceback.format_exc())
172 | 176 |   logger.error('👆👆👆There are errors in the dataset, the data will be deleted')
173 | 177 |   self._traceback_counter += 1
174 | 178 |   row = []
@@ -256,15 +260,21 @@ def __call__(
256 | 260 |   dataset = self.prepare_dataset(dataset)
257 | 261 |   dataset = self._cast_pil_image(dataset)
258 | 262 |   map_kwargs = {}
    | 263 | + ignore_max_length_error = False
259 | 264 |   if isinstance(dataset, HfDataset):
260 | 265 |   map_kwargs['num_proc'] = num_proc
    | 266 | + if num_proc > 1:
    | 267 | + ignore_max_length_error = True
261 | 268 |   with self._patch_arrow_writer():
262 | 269 |   try:
263 | 270 |   dataset_mapped = dataset.map(
264 | 271 |   self.batched_preprocess,
265 | 272 |   batched=True,
266 | 273 |   batch_size=batch_size,
267 |     | - fn_kwargs={'strict': strict},
    | 274 | + fn_kwargs={
    | 275 | + 'strict': strict,
    | 276 | + 'ignore_max_length_error': ignore_max_length_error
    | 277 | + },
268 | 278 |   remove_columns=list(dataset.features.keys()),
269 | 279 |   **map_kwargs)
270 | 280 |   except NotImplementedError:

swift/llm/dataset/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
183 | 183 |   raise
184 | 184 |   if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
185 | 185 |   import traceback
186 |     | - print(traceback.format_exc())
    | 186 | + logger.info(traceback.format_exc())
187 | 187 |   logger.error('👆👆👆There are errors in the template.encode, '
188 | 188 |   'and another piece of data will be randomly selected.')
189 | 189 |   self._traceback_counter += 1

swift/llm/infer/deploy.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ def pre_infer_hook(kwargs):
159 | 159 |   res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
160 | 160 |   except Exception as e:
161 | 161 |   import traceback
162 |     | - print(traceback.format_exc())
    | 162 | + logger.info(traceback.format_exc())
163 | 163 |   return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
164 | 164 |   if request_config.stream:
165 | 165 |

0 commit comments

Comments
 (0)