Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 132 additions & 51 deletions vec_inf/client/_slurm_script_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Class for generating SLURM scripts to run vLLM servers."""

from datetime import datetime
from pathlib import Path
from typing import Any
Expand All @@ -6,7 +8,24 @@


class SlurmScriptGenerator:
"""A class to generate SLURM scripts for running vLLM servers.

This class handles the generation of SLURM scripts for both single-node and
multi-node configurations, supporting different virtualization environments
(venv or singularity).

Args:
params (dict[str, Any]): Configuration parameters for the SLURM script
src_dir (str): Source directory path containing necessary scripts
"""

def __init__(self, params: dict[str, Any], src_dir: str):
"""Initialize the SlurmScriptGenerator with configuration parameters.

Args:
params (dict[str, Any]): Configuration parameters for the SLURM script
src_dir (str): Source directory path containing necessary scripts
"""
self.params = params
self.src_dir = src_dir
self.is_multinode = int(self.params["num_nodes"]) > 1
Expand All @@ -16,13 +35,25 @@ def __init__(self, params: dict[str, Any], src_dir: str):
self.task = VLLM_TASK_MAP[self.params["model_type"]]

def _generate_script_content(self) -> str:
"""Generate the complete SLURM script content.

Returns
-------
str: The complete SLURM script as a string
"""
preamble = self._generate_preamble()
server = self._generate_server_script()
launcher = self._generate_launcher()
args = self._generate_shared_args()
return preamble + server + launcher + args

def _generate_preamble(self) -> str:
"""Generate the SLURM script preamble with job specifications.

Returns
-------
str: SLURM preamble containing resource requests and job parameters
"""
base = [
"#!/bin/bash",
"#SBATCH --cpus-per-task=16",
Expand All @@ -37,6 +68,15 @@ def _generate_preamble(self) -> str:
return "\n".join(base)

def _generate_shared_args(self) -> str:
"""Generate the command-line arguments for the vLLM server.

Handles both single-node and multi-node configurations, setting appropriate
parallel processing parameters based on the configuration.

Returns
-------
str: Command-line arguments for the vLLM server
"""
if self.is_multinode and not self.params["pipeline_parallelism"]:
tensor_parallel_size = (
self.params["num_nodes"] * self.params["gpus_per_node"]
Expand Down Expand Up @@ -77,88 +117,121 @@ def _generate_shared_args(self) -> str:
return "\n".join(args)

def _generate_server_script(self) -> str:
"""Generate the server initialization script.

Creates the script section that handles server setup, including Ray
initialization for multi-node setups and port configuration.

Returns
-------
str: Server initialization script content
"""
server_script = [""]
if self.params["venv"] == "singularity":
server_script.append("""module load singularity-ce/3.8.2
singularity exec $SINGULARITY_IMAGE ray stop
""")
server_script.append(
"module load singularity-ce/3.8.2\n"
"singularity exec $SINGULARITY_IMAGE ray stop\n"
)
server_script.append(f"source {self.src_dir}/find_port.sh\n")
server_script.append(
self._generate_multinode_server_script()
if self.is_multinode
else self._generate_single_node_server_script()
)
server_script.append(f"""json_path="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
jq --arg server_addr "$server_address" \\
'. + {{"server_address": $server_addr}}' \\
"$json_path" > temp.json \\
&& mv temp.json "$json_path"

""")
server_script.append(
f'json_path="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"\n'
'jq --arg server_addr "$server_address" \\\n'
" '. + {{\"server_address\": $server_addr}}' \\\n"
' "$json_path" > temp.json \\\n'
' && mv temp.json "$json_path"\n\n'
)
return "\n".join(server_script)

def _generate_single_node_server_script(self) -> str:
return """hostname=${SLURMD_NODENAME}
vllm_port_number=$(find_available_port ${hostname} 8080 65535)
"""Generate the server script for single-node deployment.

server_address="http://${hostname}:${vllm_port_number}/v1"
echo "Server address: $server_address"
"""
Returns
-------
str: Script content for single-node server setup
"""
return (
"hostname=${SLURMD_NODENAME}\n"
"vllm_port_number=$(find_available_port ${hostname} 8080 65535)\n\n"
'server_address="http://${hostname}:${vllm_port_number}/v1"\n'
'echo "Server address: $server_address"\n'
)

def _generate_multinode_server_script(self) -> str:
server_script = []
server_script.append("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
"""Generate the server script for multi-node deployment.

head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
Creates a script that initializes Ray cluster with head and worker nodes,
configuring networking and GPU resources appropriately.

head_node_port=$(find_available_port $head_node_ip 8080 65535)

ip_head=$head_node_ip:$head_node_port
export ip_head
echo "IP Head: $ip_head"

echo "Starting HEAD at $head_node"
srun --nodes=1 --ntasks=1 -w "$head_node" \\""")
Returns
-------
str: Script content for multi-node server setup
"""
server_script = []
server_script.append(
'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")\n'
"nodes_array=($nodes)\n\n"
"head_node=${nodes_array[0]}\n"
'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)\n\n'
"head_node_port=$(find_available_port $head_node_ip 8080 65535)\n\n"
"ip_head=$head_node_ip:$head_node_port\n"
"export ip_head\n"
'echo "IP Head: $ip_head"\n\n'
'echo "Starting HEAD at $head_node"\n'
'srun --nodes=1 --ntasks=1 -w "$head_node" \\'
)

if self.params["venv"] == "singularity":
server_script.append(
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} "
"--containall $SINGULARITY_IMAGE \\"
)

server_script.append(""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &

sleep 10
worker_num=$((SLURM_JOB_NUM_NODES - 1))

for ((i = 1; i <= worker_num; i++)); do
node_i=${nodes_array[$i]}
echo "Starting WORKER $i at $node_i"
srun --nodes=1 --ntasks=1 -w "$node_i" \\""")
server_script.append(
' ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\\n'
' --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &\n\n'
"sleep 10\n"
"worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n"
"for ((i = 1; i <= worker_num; i++)); do\n"
" node_i=${nodes_array[$i]}\n"
' echo "Starting WORKER $i at $node_i"\n'
' srun --nodes=1 --ntasks=1 -w "$node_i" \\'
)

if self.params["venv"] == "singularity":
server_script.append(
f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} "
"--containall $SINGULARITY_IMAGE \\"
)
server_script.append(""" ray start --address "$ip_head" \\
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
sleep 5
done

vllm_port_number=$(find_available_port $head_node_ip 8080 65535)

server_address="http://${head_node_ip}:${vllm_port_number}/v1"
echo "Server address: $server_address"

""")
server_script.append(
' ray start --address "$ip_head" \\\n'
' --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &\n'
" sleep 5\n"
"done\n\n"
"vllm_port_number=$(find_available_port $head_node_ip 8080 65535)\n\n"
'server_address="http://${head_node_ip}:${vllm_port_number}/v1"\n'
'echo "Server address: $server_address"\n\n'
)
return "\n".join(server_script)

def _generate_launcher(self) -> str:
"""Generate the vLLM server launch command.

Creates the command to launch the vLLM server, handling different virtualization
environments (venv or singularity).

Returns
-------
str: Server launch command
"""
if self.params["venv"] == "singularity":
launcher_script = [
f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} --containall $SINGULARITY_IMAGE \\"""
]
else:
launcher_script = [f"""source {self.params["venv"]}/bin/activate"""]
Expand All @@ -168,6 +241,14 @@ def _generate_launcher(self) -> str:
return "\n".join(launcher_script)

def write_to_log_dir(self) -> Path:
"""Write the generated SLURM script to the log directory.

Creates a timestamped script file in the configured log directory.

Returns
-------
Path: Path to the generated SLURM script file
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
script_path: Path = (
Path(self.params["log_dir"])
Expand Down
Loading