@@ -210,10 +210,10 @@ def _init_grpo(self):
210
210
logger .info (f'Auto-configured soft_max_length = max_completion_length { self .max_completion_length } ' )
211
211
if self .use_vllm :
212
212
# set vllm mode
213
- if self .vllm_server_host is not None :
213
+ if self .vllm_server_host is not None or self . vllm_server_base_url is not None :
214
214
if self .vllm_mode != 'server' :
215
215
self .vllm_mode = 'server'
216
- logger .warning ('set vllm_mode to `server` since vllm_server_host is provided' )
216
+ logger .warning ('set vllm_mode to `server` since vllm server host/base_url is provided' )
217
217
else :
218
218
if self .vllm_mode != 'colocate' :
219
219
self .vllm_mode = 'colocate'
@@ -250,7 +250,7 @@ def _init_rm(self):
250
250
self .num_labels = 1
251
251
252
252
def _init_external_vllm (self ):
253
- if self .rlhf_type != 'grpo' or self .vllm_server_host is None :
253
+ if self .rlhf_type != 'grpo' or ( self .vllm_server_host is None and self . vllm_server_base_url is None ) :
254
254
return
255
255
from swift .trainers .rlhf_trainer .vllm_client import VLLMClient
256
256
if is_master ():
@@ -310,7 +310,7 @@ def _check_grpo(self):
310
310
assert is_liger_kernel_available (), (
311
311
'Please install/update liger-kernel by running: pip install -U liger-kernel' )
312
312
if self .vllm_mode == 'server' :
313
- assert not self .use_vllm or self .vllm_server_host is not None
313
+ assert not self .use_vllm or self .vllm_server_host is not None or self . vllm_server_base_url is not None
314
314
315
315
if self .async_generate :
316
316
assert self .vllm_mode == 'server' , 'async generate require vllm_mode == server, '
0 commit comments