@@ -213,10 +213,10 @@ def _init_grpo(self):
213
213
logger .info (f'Auto-configured soft_max_length = max_completion_length { self .max_completion_length } ' )
214
214
if self .use_vllm :
215
215
# set vllm mode
216
- if self .vllm_server_host is not None :
216
+ if self .vllm_server_host is not None or self . vllm_server_base_url is not None :
217
217
if self .vllm_mode != 'server' :
218
218
self .vllm_mode = 'server'
219
- logger .warning ('set vllm_mode to `server` since vllm_server_host is provided' )
219
+ logger .warning ('set vllm_mode to `server` since vllm server host/base_url is provided' )
220
220
else :
221
221
if self .vllm_mode != 'colocate' :
222
222
self .vllm_mode = 'colocate'
@@ -253,7 +253,7 @@ def _init_rm(self):
253
253
self .num_labels = 1
254
254
255
255
def _init_external_vllm (self ):
256
- if self .rlhf_type != 'grpo' or self .vllm_server_host is None :
256
+ if self .rlhf_type != 'grpo' or ( self .vllm_server_host is None and self . vllm_server_base_url is None ) :
257
257
return
258
258
from swift .trainers .rlhf_trainer .vllm_client import VLLMClient
259
259
if is_master ():
@@ -313,7 +313,7 @@ def _check_grpo(self):
313
313
assert is_liger_kernel_available (), (
314
314
'Please install/update liger-kernel by running: pip install -U liger-kernel' )
315
315
if self .vllm_mode == 'server' :
316
- assert not self .use_vllm or self .vllm_server_host is not None
316
+ assert not self .use_vllm or self .vllm_server_host is not None or self . vllm_server_base_url is not None
317
317
318
318
if self .async_generate :
319
319
assert self .vllm_mode == 'server' , 'async generate require vllm_mode == server, '
0 commit comments