2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ description = "Efficient LLM inference on Slurm clusters using vLLM."
readme = "README.md"
authors = [{name = "Marshall Wang", email = "marshall.wang@vectorinstitute.ai"}]
license = "MIT"
-requires-python = ">=3.10"
+requires-python = ">=3.10,<4.0"
Contributor: Don't think we need this.

Contributor Author: Yes, but when I didn't have it, it was trying to use Python 3.12 and ran into dependency issues (I was still using Poetry at the time); that's why I changed it. But I guess we can leave that for another PR if it's still a problem.
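A quick sanity check of how that range behaves (a minimal sketch using the packaging library, not part of this diff): the upper bound does not actually exclude 3.12; its practical effect under Poetry is mainly to keep the project's Python constraint compatible with dependencies that declare their own "<4.0" cap.

from packaging.specifiers import SpecifierSet

# ">=3.10,<4.0" still admits 3.12; the cap mainly helps dependency
# resolution when packages pin their own upper bounds.
spec = SpecifierSet(">=3.10,<4.0")
print("3.12" in spec)  # True
print("3.9" in spec)   # False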

dependencies = [
    "requests>=2.31.0",
    "click>=8.1.0",
41 changes: 11 additions & 30 deletions vec_inf/cli/_helper.py
@@ -16,6 +16,7 @@

import vec_inf.cli._utils as utils
from vec_inf.cli._config import ModelConfig
+from vec_inf.cli._slurm_script_generator import SlurmScriptGenerator


VLLM_TASK_MAP = {
@@ -127,31 +128,7 @@ def _get_launch_params(self) -> dict[str, Any]:

    def set_env_vars(self) -> None:
        """Set environment variables for the launch command."""
-        os.environ["MODEL_NAME"] = self.model_name
-        os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"]
-        os.environ["MAX_LOGPROBS"] = self.params["vocab_size"]
-        os.environ["DATA_TYPE"] = self.params["data_type"]
-        os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"]
-        os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"]
-        os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]]
-        os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"]
-        os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"]
-        os.environ["SRC_DIR"] = SRC_DIR
-        os.environ["MODEL_WEIGHTS"] = str(
-            Path(self.params["model_weights_parent_dir"], self.model_name)
-        )
-        os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
-        os.environ["VENV_BASE"] = self.params["venv"]
-        os.environ["LOG_DIR"] = self.params["log_dir"]
-
-        if self.params.get("enable_prefix_caching"):
-            os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
-        if self.params.get("enable_chunked_prefill"):
-            os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
-        if self.params.get("max_num_batched_tokens"):
-            os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
-        if self.params.get("enforce_eager"):
-            os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]

    def build_launch_command(self) -> str:
        """Construct the full launch command with parameters."""
@@ -176,11 +153,12 @@ def build_launch_command(self) -> str:
f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err",
]
)
# Add slurm script
slurm_script = "vllm.slurm"
if int(self.params["num_nodes"]) > 1:
slurm_script = "multinode_vllm.slurm"
command_list.append(f"{SRC_DIR}/{slurm_script}")

slurm_script_path = SlurmScriptGenerator(
self.params, src_dir=SRC_DIR
).write_to_log_dir()

command_list.append(str(slurm_script_path))
return " ".join(command_list)

    def format_table_output(self, job_id: str) -> Table:
@@ -214,7 +192,10 @@ def format_table_output(self, job_id: str) -> Table:
        )
        if self.params.get("enforce_eager"):
            table.add_row("Enforce Eager", self.params["enforce_eager"])
-        table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS"))
+        table.add_row(
+            "Model Weights Directory",
+            str(Path(self.params["model_weights_parent_dir"], self.model_name)),
+        )
        table.add_row("Log Directory", self.params["log_dir"])

        return table
194 changes: 194 additions & 0 deletions vec_inf/cli/_slurm_script_generator.py
@@ -0,0 +1,194 @@
from datetime import datetime
from pathlib import Path
from typing import Any


VLLM_TASK_MAP = {
    "LLM": "generate",
    "VLM": "generate",
    "Text_Embedding": "embed",
    "Reward_Modeling": "reward",
}


class SlurmScriptGenerator:
    """Generates a Slurm launch script for a vLLM server from resolved launch parameters."""

    def __init__(self, params: dict[str, Any], src_dir: str):
        self.params = params
        self.src_dir = src_dir
        self.is_multinode = int(self.params["num_nodes"]) > 1
        self.model_weights_path = str(
            Path(params["model_weights_parent_dir"], params["model_name"])
        )
        self.task = VLLM_TASK_MAP[self.params["model_type"]]

    def _generate_script_content(self) -> str:
        """Assemble the script: preamble, server setup, parallel env exports, launcher, vLLM args."""
        preamble = self._generate_preamble()
        server = self._generate_server_script()
        env_exports = self._export_parallel_vars()
        launcher = self._generate_launcher()
        args = self._generate_shared_args()
        return preamble + server + env_exports + launcher + args

    def _generate_preamble(self) -> str:
        base = [
            "#!/bin/bash",
            "#SBATCH --cpus-per-task=16",
            "#SBATCH --mem=64G",
        ]
        if self.is_multinode:
            base += [
                "#SBATCH --exclusive",
                "#SBATCH --tasks-per-node=1",
            ]
        base += [""]
        return "\n".join(base)

    def _export_parallel_vars(self) -> str:
Contributor: These are only being exported so that they can be consumed when generating the shared args, right? Can we just populate the parameters directly instead?

Contributor Author: True, will remove this.

Contributor Author: Forgot to tag; fixed this.
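For illustration, populating the parameters directly might look roughly like this (a hypothetical helper, not code in this commit; gpus_per_node is an assumed params key that this PR does not define):

    # Hypothetical sketch: resolve parallelism sizes in Python rather than
    # exporting env vars for the generated script to branch on at runtime.
    def _resolve_parallel_sizes(self) -> tuple[int, int]:
        num_nodes = int(self.params["num_nodes"])
        gpus_per_node = int(self.params["gpus_per_node"])  # assumed key
        if self.is_multinode and self.params.get("pipeline_parallelism") == "True":
            return num_nodes, gpus_per_node  # (pipeline_size, tensor_size)
        return 1, num_nodes * gpus_per_node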

        if self.is_multinode:
            return """if [ "$PIPELINE_PARALLELISM" = "True" ]; then
Contributor: Since we're generating these scripts dynamically, we shouldn't need any if/else statements in the generated scripts. This logic should go directly into _generate_shared_args, which looks like it already has the same checks.

Contributor Author: Forgot to tag; fixed this.
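Building on the hypothetical helper sketched above, _generate_shared_args could then emit concrete values, leaving no branching in the generated script:

        # Hypothetical continuation: bake resolved sizes straight into the args.
        pipeline_size, tensor_size = self._resolve_parallel_sizes()
        args = [f"--tensor-parallel-size {tensor_size}"]
        if self.is_multinode:
            args.append(f"--pipeline-parallel-size {pipeline_size}")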

export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
else
export PIPELINE_PARALLEL_SIZE=1
export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
fi

"""
return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n\n"

    def _generate_shared_args(self) -> str:
        """Build the vLLM engine arguments as a backslash-continued block."""
        args = [
            f"--model {self.model_weights_path}",
            f"--served-model-name {self.params['model_name']}",
            '--host "0.0.0.0"',
            "--port $vllm_port_number",
            "--tensor-parallel-size ${TENSOR_PARALLEL_SIZE}",
            f"--dtype {self.params['data_type']}",
            "--trust-remote-code",
            f"--max-logprobs {self.params['vocab_size']}",
            f"--max-model-len {self.params['max_model_len']}",
            f"--max-num-seqs {self.params['max_num_seqs']}",
            f"--gpu-memory-utilization {self.params['gpu_memory_utilization']}",
            f"--compilation-config {self.params['compilation_config']}",
            f"--task {self.task}",
        ]
        if self.is_multinode:
            args.insert(4, "--pipeline-parallel-size ${PIPELINE_PARALLEL_SIZE}")
        if self.params.get("max_num_batched_tokens"):
            args.append(
                f"--max-num-batched-tokens={self.params['max_num_batched_tokens']}"
            )
        if self.params.get("enable_prefix_caching") == "True":
            args.append("--enable-prefix-caching")
        if self.params.get("enable_chunked_prefill") == "True":
            args.append("--enable-chunked-prefill")
        if self.params.get("enforce_eager") == "True":
            args.append("--enforce-eager")
        # Join with line continuations; the final argument must not end with "\".
        return " \\\n".join(args)

    def _generate_server_script(self) -> str:
        server_script = [""]
        if self.params["venv"] == "singularity":
            server_script.append("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
module load singularity-ce/3.8.2
singularity exec $SINGULARITY_IMAGE ray stop
""")
        server_script.append(f"source {self.src_dir}/find_port.sh\n")
        server_script.append(
            self._generate_multinode_server_script()
            if self.is_multinode
            else self._generate_single_node_server_script()
        )
        # Define JSON_PATH before echoing it; a successful mv already consumes
        # temp.json, so no separate cleanup is needed.
        server_script.append(f"""JSON_PATH="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
echo "Updating server address in $JSON_PATH"
jq --arg server_addr "$SERVER_ADDR" \\
'. + {{"server_address": $server_addr}}' \\
"$JSON_PATH" > temp.json \\
&& mv temp.json "$JSON_PATH"

""")
        return "\n".join(server_script)

    def _generate_single_node_server_script(self) -> str:
        return """hostname=${SLURMD_NODENAME}
vllm_port_number=$(find_available_port ${hostname} 8080 65535)

SERVER_ADDR="http://${hostname}:${vllm_port_number}/v1"
echo "Server address: $SERVER_ADDR"
"""

    def _generate_multinode_server_script(self) -> str:
        server_script = []
        server_script.append("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)

head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

head_node_port=$(find_available_port $head_node_ip 8080 65535)

ip_head=$head_node_ip:$head_node_port
export ip_head
echo "IP Head: $ip_head"

echo "Starting HEAD at $head_node"
srun --nodes=1 --ntasks=1 -w "$head_node" \\""")

        if self.params["venv"] == "singularity":
            server_script.append(
                f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"
            )

        server_script.append(""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &

sleep 10
worker_num=$((SLURM_JOB_NUM_NODES - 1))

for ((i = 1; i <= worker_num; i++)); do
node_i=${nodes_array[$i]}
echo "Starting WORKER $i at $node_i"
srun --nodes=1 --ntasks=1 -w "$node_i" \\""")

        if self.params["venv"] == "singularity":
            server_script.append(
                f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
            )
        server_script.append(""" ray start --address "$ip_head" \\
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
sleep 5
done

vllm_port_number=$(find_available_port $head_node_ip 8080 65535)

SERVER_ADDR="http://${head_node_ip}:${vllm_port_number}/v1"
echo "Server address: $SERVER_ADDR"

""")
        return "\n".join(server_script)

    def _generate_launcher(self) -> str:
        if self.params["venv"] == "singularity":
            launcher_script = [
                f"singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"
            ]
        else:
            launcher_script = [f"source {self.params['venv']}/bin/activate"]
        # Trailing backslash + newline so the shared args can follow directly.
        launcher_script.append(
            "python3.10 -m vllm.entrypoints.openai.api_server \\\n"
        )
        return "\n".join(launcher_script)

    def write_to_log_dir(self) -> Path:
        """Render the script into the model's log directory and return its path."""
        log_subdir: Path = Path(self.params["log_dir"]) / self.params["model_name"]
        log_subdir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        script_path: Path = log_subdir / f"launch_{timestamp}.slurm"

        content = self._generate_script_content()
        script_path.write_text(content)
        return script_path
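For reference, the call site added in _helper.py above reduces to the following usage (a sketch: the model name, paths, and engine values are all illustrative, not defaults from this repo):

from vec_inf.cli._slurm_script_generator import SlurmScriptGenerator

params = {
    "model_name": "Meta-Llama-3.1-8B",  # illustrative values throughout
    "model_type": "LLM",
    "num_nodes": "1",
    "model_weights_parent_dir": "/model-weights",
    "log_dir": "/tmp/vec-inf-logs",
    "venv": "/opt/vec-inf-venv",
    "data_type": "auto",
    "vocab_size": "128256",
    "max_model_len": "8192",
    "max_num_seqs": "256",
    "gpu_memory_utilization": "0.9",
    "compilation_config": "0",
}
script_path = SlurmScriptGenerator(params, src_dir="/opt/vec-inf/vec_inf").write_to_log_dir()
print(f"sbatch {script_path}")  # build_launch_command prepends the sbatch flags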