Skip to content

Commit a61c319

Browse files
authored
[rollout] fix request from dict (#4826)
1 parent 56e84a2 commit a61c319

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

swift/llm/infer/infer_engine/grpo_vllm_engine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,8 @@ async def async_infer(self,
135135
async def _infer_async_single(infer_request: Union[RolloutInferRequest, Dict[str, Any]],
136136
request_config: Optional[RequestConfig] = None,
137137
**kwargs):
138-
# discard origin last turn reponse in first turn
139138
if isinstance(infer_request, Dict):
140-
infer_request = RolloutInferRequest(
141-
messages=infer_request['messages'], data_dict=infer_request.get('data_dict', None))
139+
infer_request = RolloutInferRequest(**infer_request)
142140
current_request = infer_request
143141
current_turn = 1
144142
while True:
@@ -169,6 +167,7 @@ async def _infer_async_single(infer_request: Union[RolloutInferRequest, Dict[str
169167
return result
170168

171169
current_request = self.multi_turn_scheduler.step(current_request, result_choice, current_turn)
170+
assert isinstance(current_request, RolloutInferRequest)
172171
if current_request.messages[-1]['role'] == 'assistant':
173172
# NOTE: engine will discard last response during inference
174173
# https://github.com/modelscope/ms-swift/blob/v3.5.1/swift/llm/template/base.py#L416-L419

0 commit comments

Comments
 (0)