Skip to content

Commit 8a527fe

Browse files
hjh0119Jintao-Huang
authored andcommitted
[bugfix] fix import issues (#5407)
1 parent 242d26e commit 8a527fe

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

swift/llm/infer/infer_engine/grpo_vllm_engine.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tqdm.asyncio import tqdm_asyncio
99

1010
from swift.llm import InferRequest, RolloutInferRequest, Template, VllmEngine
11+
from swift.llm.infer.protocol import MultiModalRequestMixin
1112
from swift.plugin import Metric, multi_turns
1213
from swift.plugin.context_manager import ContextManager, context_managers
1314
from swift.plugin.env import Env, envs
@@ -295,13 +296,16 @@ async def _multi_turn_sampling_controller(self, infer_request: RolloutInferReque
295296
if should_stop:
296297
result_choice.messages = messages
297298
info_dict['num_turns'] = current_turn
298-
for key, value in info_dict.items():
299+
for key, values in info_dict.items():
299300
if key in ['images', 'audios', 'videos']:
300-
value = MultiModalRequestMixin.to_base64(value)
301+
if not isinstance(values, list):
302+
values = [values]
303+
for i, value in enumerate(values):
304+
values[i] = MultiModalRequestMixin.to_base64(value)
301305
if hasattr(result_choice, key):
302-
setattr(result_choice, key, value)
306+
setattr(result_choice, key, values)
303307
else:
304-
result_choice.multi_turn_infos[key] = value
308+
result_choice.multi_turn_infos[key] = values
305309
return result
306310

307311
ret = self.multi_turn_scheduler.step(current_request, result_choice, current_turn)

swift/plugin/multi_turn.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,12 @@ def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseCho
6767

6868

6969
class MathTipsMultiTurnScheduler(MultiTurnScheduler):
70-
from .orm import MathAccuracy
71-
tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.'
72-
acc_func = MathAccuracy()
70+
71+
def __init__(self, max_turns: Optional[int] = None, *args, **kwargs):
72+
super().__init__(max_turns, *args, **kwargs)
73+
from .orm import MathAccuracy
74+
self.tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.'
75+
self.acc_func = MathAccuracy()
7376

7477
def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
7578
current_turn: int) -> bool:

0 commit comments

Comments
 (0)