Skip to content
42 changes: 41 additions & 1 deletion mii/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# DeepSpeed Team
import torch
import socket
from typing import Union, List
from enum import Enum
from pydantic import BaseModel, validator, root_validator
Expand Down Expand Up @@ -39,7 +40,7 @@ def __repr__(self):

class MIIConfig(BaseModel):
tensor_parallel: int = 1
port_number: int = 50050
port_number: int = None
dtype: DtypeEnum = torch.float32
meta_tensor: bool = False
load_with_sys_mem: bool = False
Expand All @@ -57,6 +58,45 @@ class MIIConfig(BaseModel):
replica_num: int = 1
hostfile: str = DLTS_HOSTFILE
trust_remote_code: bool = False

def __is_port_in_use(port_number: int) -> bool:
"""
Checks if a port_number is in use

Args:
port_number (int): port_number to check

Returns:
bool: True if port_number is in use else False
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port_number)) == 0

@validator('port_number')
def assign_port(port_number: int = None) -> int:
"""
Starts a socket connection to grab a free port (Involves a race
condition but will do for now)
Args:
port_number (int): Port to start the socket connection (default: None)
Returns:
int: Free port number
"""
DEFAULT_PORT = 50050
# if port is None set the default 50050 and default port is not in use return it
if port_number is None:
port_number = DEFAULT_PORT

# if the defined port is in use find a free port
if MIIConfig.__is_port_in_use(port_number):
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tcp.bind(("", 0))
_, port_number = tcp.getsockname()
tcp.close()

MIIConfig.port_number = port_number

return port_number

@validator("deploy_rank")
def deploy_valid(cls, field_value, values):
Expand Down