Skip to content

Commit 32e335d

Browse files
authored
fix missing host/port for vllm (axolotl-ai-cloud#2543)
* fix missing host/port for vllm
* set tensor parallel size so it doesn't always default to cli override
1 parent 7651550 commit 32e335d

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

src/axolotl/cli/args.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ class TrainerCliArgs:
3939
class VllmServeCliArgs:
4040
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
4141

42-
tensor_parallel_size: int = field(
43-
default=1,
42+
tensor_parallel_size: Optional[int] = field(
43+
default=None,
4444
metadata={"help": "Number of tensor parallel workers to use."},
4545
)
46-
host: str = field(
47-
default="0.0.0.0", # nosec B104
46+
host: Optional[str] = field(
47+
default=None, # nosec B104
4848
metadata={"help": "Host address to run the server on."},
4949
)
50-
port: int = field(
51-
default=8000,
50+
port: Optional[int] = field(
51+
default=None,
5252
metadata={"help": "Port to run the server on."},
5353
)
5454
gpu_memory_utilization: Optional[float] = field(

src/axolotl/core/trainers/grpo/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def set_training_args_kwargs(cls, cfg):
4040

4141
if trl.use_vllm:
4242
grpo_args_kwargs["use_vllm"] = trl.use_vllm
43-
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
44-
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
43+
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
44+
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
4545
if trl.vllm_server_timeout:
4646
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
4747
if trl.vllm_guided_decoding_regex:

src/axolotl/utils/schemas/vllm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
3636
default=None,
3737
json_schema_extra={"description": "Enable prefix caching for VLLM"},
3838
)
39+
host: str | None = Field(
40+
default="0.0.0.0", # nosec B104
41+
json_schema_extra={"description": "Host for the vLLM server to start on"},
42+
)
43+
port: int | None = Field(
44+
default=8000,
45+
json_schema_extra={"description": "Port of the vLLM server to start on"},
46+
)

0 commit comments

Comments (0)