3 files changed: +16 -8 lines changed
Original file line number Diff line number Diff line change @@ -39,16 +39,16 @@ class TrainerCliArgs:
39
39
class VllmServeCliArgs :
40
40
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
41
41
42
- tensor_parallel_size : int = field (
43
- default = 1 ,
42
+ tensor_parallel_size : Optional [ int ] = field (
43
+ default = None ,
44
44
metadata = {"help" : "Number of tensor parallel workers to use." },
45
45
)
46
- host : str = field (
47
- default = "0.0.0.0" , # nosec B104
46
+ host : Optional [ str ] = field (
47
+ default = None , # nosec B104
48
48
metadata = {"help" : "Host address to run the server on." },
49
49
)
50
- port : int = field (
51
- default = 8000 ,
50
+ port : Optional [ int ] = field (
51
+ default = None ,
52
52
metadata = {"help" : "Port to run the server on." },
53
53
)
54
54
gpu_memory_utilization : Optional [float ] = field (
Original file line number Diff line number Diff line change @@ -40,8 +40,8 @@ def set_training_args_kwargs(cls, cfg):
40
40
41
41
if trl .use_vllm :
42
42
grpo_args_kwargs ["use_vllm" ] = trl .use_vllm
43
- grpo_args_kwargs ["vllm_server_host" ] = trl .vllm_server_host
44
- grpo_args_kwargs ["vllm_server_port" ] = trl .vllm_server_port
43
+ grpo_args_kwargs ["vllm_server_host" ] = trl .vllm_server_host or trl . vllm . host
44
+ grpo_args_kwargs ["vllm_server_port" ] = trl .vllm_server_port or trl . vllm . port
45
45
if trl .vllm_server_timeout :
46
46
grpo_args_kwargs ["vllm_server_timeout" ] = trl .vllm_server_timeout
47
47
if trl .vllm_guided_decoding_regex :
Original file line number Diff line number Diff line change @@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
36
36
default = None ,
37
37
json_schema_extra = {"description" : "Enable prefix caching for VLLM" },
38
38
)
39
+ host : str | None = Field (
40
+ default = "0.0.0.0" , # nosec B104
41
+ json_schema_extra = {"description" : "Host for the vLLM server to start on" },
42
+ )
43
+ port : int | None = Field (
44
+ default = 8000 ,
45
+ json_schema_extra = {"description" : "Port of the vLLM server to start on" },
46
+ )
You can’t perform that action at this time.
0 commit comments