@validator('port_number', always=True)
def assign_port(cls, field_value, values):
    """Resolve the TCP port number this configuration will use.

    * A port supplied by the user is kept as-is, but rejected if something
      is already accepting connections on it on localhost.
    * When no port is supplied (``None``), ``cls._DEFAULT_PORT`` is tried
      first; if that port is busy, the OS is asked for a free ephemeral
      port instead.

    ``always=True`` is required on the decorator: pydantic does not call
    validators for fields that were not supplied, so without it the
    ``None`` default would never be replaced by a real port.

    Args:
        field_value: The user-supplied port, or ``None`` when unset.
        values: Previously validated fields (unused here).

    Returns:
        The port number to use.

    Raises:
        ValueError: If the explicitly requested port is already in use.
            (Raised instead of using ``assert`` so the check survives
            ``python -O``; pydantic reports it as a validation error.)
    """
    def is_port_in_use(port_number: int) -> bool:
        # connect_ex() returns 0 only when a listener accepted the
        # connection, i.e. the port is currently taken on localhost.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(('localhost', port_number)) == 0

    # If the user sets a port, make sure we use that port (or fail loudly).
    if field_value is not None:
        if is_port_in_use(field_value):
            raise ValueError(f"Port number {field_value} already in use.")
        return field_value

    # Otherwise try the default value.
    if not is_port_in_use(cls._DEFAULT_PORT):
        return cls._DEFAULT_PORT

    # The default is in use: bind to port 0 so the OS picks a free port.
    # The context manager guarantees the probe socket is closed even if
    # bind()/getsockname() raises.
    # NOTE: the port is released again before the caller binds it, so
    # another process could grab it in between; this is best-effort.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(("", 0))
        return tcp.getsockname()[1]