Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions trl/experimental/gold/gold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ class GOLDConfig(SFTConfig):
Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).
vllm_structured_outputs_regex (`str`, *optional*):
Regex for vLLM structured outputs for the student model.
vllm_server_base_url (`str`, *optional*):
Base URL for the vLLM server (e.g., `"http://localhost:8001"`). If provided, `vllm_server_host` and
`vllm_server_port` are ignored.
vllm_group_port (`int`, *optional*, defaults to `51216`):
Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need
to change it.
vllm_max_model_length (`int`, *optional*):
Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the
model's maximum context length.
vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
Model implementation backend to use in vLLM. Use `"vllm"` (default) or `"transformers"`.
vllm_sync_frequency (`int`, *optional*, defaults to `1`):
Frequency (in training steps) at which to synchronize student model weights to the vLLM engine. Set to 1 to
sync after every step.
Expand Down Expand Up @@ -296,6 +307,12 @@ class GOLDConfig(SFTConfig):
"help": 'Mode for vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).'
},
)
vllm_server_base_url: str | None = field(
default=None,
metadata={
"help": 'Base URL for the vLLM server (e.g., "http://localhost:8001"). If provided, vllm_server_host and vllm_server_port are ignored.'
},
)
vllm_server_host: str = field(
default="0.0.0.0",
metadata={"help": 'Host of the vLLM server when `vllm_mode="server"`.'},
Expand All @@ -308,6 +325,10 @@ class GOLDConfig(SFTConfig):
default=240.0,
metadata={"help": 'Timeout (in seconds) for connecting to the vLLM server when `vllm_mode="server"`.'},
)
# NCCL communicator port used by the weight-update group when pushing
# fresh student weights to the vLLM engine.
vllm_group_port: int = field(
    metadata={"help": "Port for the vLLM weight-update group (NCCL communicator)."},
    default=51216,
)
vllm_gpu_memory_utilization: float = field(
default=0.9,
metadata={
Expand All @@ -318,6 +339,16 @@ class GOLDConfig(SFTConfig):
default=1,
metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'},
)
vllm_max_model_length: int | None = field(
default=None,
metadata={
"help": 'Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the model\'s maximum context length.'
},
)
# Backend implementation vLLM loads the model with; accepted values per
# the help text are "vllm" (native) or "transformers".
vllm_model_impl: str = field(
    metadata={
        "help": (
            'Model implementation backend to use in vLLM. Use "vllm" (default) or "transformers".'
        )
    },
    default="vllm",
)
vllm_structured_outputs_regex: str | None = field(
default=None,
metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."},
Expand Down
Loading
Loading