Skip to content
25 changes: 24 additions & 1 deletion mii/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# DeepSpeed Team
import torch
import socket
from typing import Union, List
from enum import Enum
from pydantic import BaseModel, validator, root_validator
Expand Down Expand Up @@ -39,7 +40,8 @@ def __repr__(self):

class MIIConfig(BaseModel):
tensor_parallel: int = 1
port_number: int = 50050
port_number: int = None
_DEFAULT_PORT = 50050
dtype: DtypeEnum = torch.float32
meta_tensor: bool = False
load_with_sys_mem: bool = False
Expand All @@ -58,6 +60,27 @@ class MIIConfig(BaseModel):
hostfile: str = DLTS_HOSTFILE
trust_remote_code: bool = False

@validator('port_number', always=True)
def assign_port(cls, field_value, values):
    """Resolve the TCP port stored in ``port_number``.

    ``always=True`` is required because the field defaults to ``None``:
    without it pydantic does not run validators for fields the user left
    unset, and the fallback logic below would never execute.

    Args:
        field_value: The port supplied by the user, or ``None`` when the
            field was not set.
        values: Previously validated fields (unused here).

    Returns:
        The user-supplied port when it is free; otherwise
        ``cls._DEFAULT_PORT`` when that is free; otherwise a port chosen
        by the operating system.

    Raises:
        ValueError: If the user explicitly requested a port on which
            something is already accepting connections. pydantic reports
            errors raised from validators as a validation error.
    """
    def is_port_in_use(port_number: int) -> bool:
        # connect_ex returns 0 only when the connection attempt succeeds,
        # i.e. something is already listening on that local port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(('localhost', port_number)) == 0

    # If the user sets a port, make sure we use that port. An explicit
    # exception is raised instead of `assert` so the check still runs
    # when Python is started with -O.
    if field_value is not None:
        if is_port_in_use(field_value):
            raise ValueError(f"Port number {field_value} already in use.")
        return field_value

    # Otherwise try the default value.
    if not is_port_in_use(cls._DEFAULT_PORT):
        return cls._DEFAULT_PORT

    # The default is in use: bind to port 0 so the OS picks a free port.
    # The context manager closes the probe socket even if bind() fails.
    # NOTE: the port is released before it is returned, so another process
    # could claim it in the meantime.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(("", 0))
        return tcp.getsockname()[1]

@validator("deploy_rank")
def deploy_valid(cls, field_value, values):
if "tensor_parallel" not in values:
Expand Down