@@ -228,6 +228,7 @@ def _check_grpo(self):
228228 'Please update it by running: pip install -U trl' )
229229
230230 if self .use_liger_kernel :
231+ assert trl_version >= version .parse ('0.18' )
231232 if self .delta is not None :
232233 raise ValueError ('Liger loss does not support two-sided GRPO loss yet.' )
233234 from trl .import_utils import is_liger_kernel_available
@@ -252,22 +253,22 @@ def _check_grpo(self):
252253 if self .generation_batch_size or self .steps_per_generation :
253254 from trl .trainer .grpo_config import GRPOConfig
254255 assert 'generation_batch_size' in GRPOConfig .__dict__ , (
255- 'generation_batch_size or steps_per_generation needs trl >= 0.18.dev , '
256- 'please install trl from source `pip install git+https://github.com/huggingface/ trl.git ' )
256+ 'generation_batch_size or steps_per_generation needs trl >= 0.18, '
257+ 'please install trl `pip install trl>=0.18 ' )
257258
258259 def _external_vllm_warning (self ):
259260 if self .rlhf_type != 'grpo' or not self .vllm_server_host :
260261 return
261262
262- if self .vllm_device != 'auto' :
263+ if self .vllm_device is not None :
263264 logger .warning ("Configuration conflict: External vLLM engine detected, but 'vllm_device' is set to '%s'. " ,
264265 self .vllm_device )
265266
266267 if self .vllm_max_model_len is not None :
267268 logger .warning (
268269 "Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
269270 'Please specify it when launching the inference service: '
270- '`swift deploy --max_model_len <value>`' , self .vllm_max_model_len )
271+ '`swift rollout --max_model_len <value>`' , self .vllm_max_model_len )
271272
272273 def _deprecated_warning (self ):
273274 if self .rlhf_type != 'grpo' :
0 commit comments