
Commit 2584797

Add stop_token, Modify GPU inference, Support hybrid

1 parent 27b0045 commit 2584797

File tree

12 files changed: +247 −79 lines changed

examples/lpu_inference.py

Lines changed: 3 additions & 2 deletions
@@ -5,12 +5,13 @@
     "Hello, my name is"
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=1, min_tokens=30, max_tokens=30)
+sampling_params = SamplingParams(temperature=0.8, top_p=0.8, top_k=1, repetition_penalty=1.2, max_tokens=60)

 # Create an LLM.
 #llm = LLM(model="facebook/opt-1.3b", device="fpga", pipeline_parallel_size=2)
-llm = LLM(model="meta-llama/Meta-Llama-3-8B", device="fpga", tensor_parallel_size=1)
+#llm = LLM(model="meta-llama/Meta-Llama-3-8B", device="fpga", tensor_parallel_size=1)
 #llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device="fpga", tensor_parallel_size=1)
+llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device="fpga", num_lpu_devices=2, num_gpu_devices=1)

 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
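Note: the hunk above only shows the configuration half of the example. Below is a minimal end-to-end sketch of how the updated script is typically driven; the generate/print loop follows the stock vLLM example pattern and sits outside this hunk, so treat it as an illustration rather than part of the commit.

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
# Sampling settings mirror the updated line in this hunk.
sampling_params = SamplingParams(temperature=0.8, top_p=0.8, top_k=1,
                                 repetition_penalty=1.2, max_tokens=60)

# Hybrid run introduced by this commit: 2 LPU devices plus 1 GPU device.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device="fpga",
          num_lpu_devices=2, num_gpu_devices=1)

# Standard vLLM generate/print loop (assumed; not part of this diff hunk).
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")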
examples/lpu_openai_completion_client.py (file name inferred from the call in examples/mini_testbench.sh below)

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+# Completion API
+stream = True
+prompt = "Hello, my name is"
+completion = client.completions.create(
+    model=model,
+    prompt=prompt,
+    stream=stream,
+)
+
+print("Prompt:", prompt)
+print("Completion results:")
+if stream:
+    for c in completion:
+        print(c.choices[0].text, end="")
+    print()
+else:
+    print(completion)
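Note: the client above prints the streamed completion chunk by chunk. A hedged variant that buffers the streamed text into a single string is sketched below; it uses the same OpenAI Python SDK calls as the new file, and the buffering logic is only an illustration, not part of the commit.

from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

# Stream the completion and accumulate the chunks into one string.
chunks = []
for chunk in client.completions.create(model=model,
                                       prompt="Hello, my name is",
                                       stream=True):
    chunks.append(chunk.choices[0].text)
print("".join(chunks))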

examples/mini_testbench.sh

Lines changed: 43 additions & 0 deletions
@@ -8,6 +8,7 @@ current_datetime=$(date "+%Y-%m-%d %H:%M:%S")
 echo "$current_datetime"
 echo "$current_datetime" >> ${log_sum}

+# LLMEngine Test
 for model_id in "${model_ids[@]}"; do
     for num_device in "${num_devices[@]}"; do
         #IFS='\' read -ra parts <<< "$model_id"
@@ -24,6 +25,7 @@ for model_id in "${model_ids[@]}"; do
     done
 done

+# LLMEngineAsync Test with vLLM serve
 for model_id in "${model_ids[@]}"; do
     for num_device in "${num_devices[@]}"; do
         model_name=$(echo "$model_id" | awk -F'/' '{print $NF}')
@@ -65,3 +67,44 @@ for model_id in "${model_ids[@]}"; do
 done


+
+# OpenAI API Test
+model_id=${model_ids[0]}
+num_device=${num_devices[0]}
+model_name=$(echo "$model_id" | awk -F'/' '{print $NF}')
+echo "*********************************"
+echo "**** Start serving_${model_name}_${num_device}"
+echo "*********************************"
+python -m vllm.entrypoints.api_server --model ${model_id} --device fpga --tensor-parallel-size ${num_device} &
+
+# Waiting for server
+while ! nc -z localhost "8000"; do
+    echo "[Testbench] Waiting for server..."
+    sleep 3
+done
+echo "[Testbench] The server is ready!"
+
+python lpu_openai_completion_client.py > log/openai_serve_${model_name}_${num_device}.txt
+
+# Waiting for process kill
+PID=$(jobs -p | tail -n 1)
+if [ -n "$PID" ]; then
+    kill -SIGINT "$PID"
+    while true; do
+        if ps -p "$PID" > /dev/null; then
+            echo "[Testbench] Kill the process..."
+            sleep 3
+        else
+            echo "[Testbench] Process (PID: $PID) is killed."
+            break
+        fi
+    done
+fi
+
+# Write log in text file
+echo "*********************************" >> ${log_sum}
+echo "The Result of log/openai_serve_${model_name}_${num_device}.txt" >> ${log_sum}
+tail -n 1 "log/openai_serve_${model_name}_${num_device}.txt" >> ${log_sum}
+echo "" >> ${log_sum}
+
+
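Note: the testbench polls the server port with nc -z and then tears the background server down with SIGINT. If the readiness check ever has to run where nc is unavailable, a Python equivalent could look like the sketch below; the function name and timeout are illustrative assumptions, not part of the commit.

import socket
import time

def wait_for_server(host: str = "localhost", port: int = 8000,
                    timeout_s: float = 180.0) -> bool:
    """Poll the TCP port until the vLLM server accepts connections."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                return True
        except OSError:
            print("[Testbench] Waiting for server...")
            time.sleep(3)
    return False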

examples/openai_test.sh

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+python -m vllm.entrypoints.openai.api_server --model facebook/opt-1.3b --device fpga --tensor-parallel-size 2

examples/vllm_serve.sh

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+
+#python -m vllm.entrypoints.api_server --model facebook/opt-1.3b --device fpga --tensor-parallel-size 2
+python -m vllm.entrypoints.api_server --model facebook/opt-1.3b --device fpga --num-gpu-devices 1 --num-lpu-devices 2
+#python -m vllm.entrypoints.api_server --model facebook/opt-1.3b --device fpga --num_gpu_devices 1 --num_lpu_devices 2
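Note: vllm_serve.sh launches the demo api_server with the new hybrid flags. Upstream vLLM's demo api_server accepts POST /generate with a JSON body containing the prompt plus sampling parameters; the request below is a hedged sketch of exercising that endpoint, and the field names are assumptions if this fork has customized the entrypoint.

import json

import requests

# Demo api_server started by vllm_serve.sh (assumed to listen on port 8000).
response = requests.post("http://localhost:8000/generate",
                         json={"prompt": "Hello, my name is",
                               "max_tokens": 60,
                               "temperature": 0.8})
response.raise_for_status()
print(json.dumps(response.json(), indent=2))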

vllm/config.py

Lines changed: 7 additions & 1 deletion
@@ -1011,14 +1011,18 @@ def is_multi_step(self) -> bool:

 class DeviceConfig:
     device: Optional[torch.device]
+    num_gpu_devices: int
+    num_lpu_devices: int

-    def __init__(self, device: str = "auto") -> None:
+    def __init__(self, device: str = "auto", num_gpu_devices: int = 0, num_lpu_devices: int = 1) -> None:
         if device == "auto":
             # Automated device type detection
             if is_neuron():
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
+            elif current_platform.is_lpu():
+                self.device_type = "fpga"
             elif current_platform.is_tpu():
                 self.device_type = "tpu"
             elif is_cpu():
@@ -1042,6 +1046,8 @@ def __init__(self, device: str = "auto") -> None:
         # Set device with device type
         self.device = torch.device(self.device_type)

+        self.num_gpu_devices = num_gpu_devices
+        self.num_lpu_devices = num_lpu_devices

 class SpeculativeConfig:
     """Configuration for speculative decoding.

vllm/core/scheduler.py

Lines changed: 118 additions & 57 deletions
@@ -7,7 +7,7 @@
 from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
                     Tuple, Union)

-from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig, DeviceConfig
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.logger import init_logger, print_logger
 from vllm.lora.request import LoRARequest
@@ -301,6 +301,7 @@ def __init__(
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
+        device_config: Optional[DeviceConfig],
         pipeline_parallel_size: int = 1,
         output_proc_callback: Optional[Callable] = None,
     ) -> None:
@@ -310,6 +311,10 @@
         # simple and NOT fair. It can lead to starvation of some
         # LoRAs. This should be improved in the future.
         self.lora_config = lora_config
+        # NOTE(hyunjun): Currently, the LPU vLLM backend needs to reduce its scheduler
+        # dependency on _can_append_slots / _append_slots.
+        # Temporarily, we change the resource management flow based on the device config.
+        self.device_config = device_config

         version = "v1"
         if self.scheduler_config.use_v2_block_manager:
@@ -576,63 +581,119 @@ def _schedule_running(
                 assert self.output_proc_callback is not None
                 self.output_proc_callback()
                 self.running = tmp
-
-            while not True: #TODO #self._can_append_slots(seq_group):
-                budget.subtract_num_batched_tokens(seq_group.request_id,
-                                                   num_running_tokens)
-                num_running_seqs = seq_group.get_max_num_running_seqs()
-                budget.subtract_num_seqs(seq_group.request_id,
-                                         num_running_seqs)
-
-                if (curr_loras is not None and seq_group.lora_int_id > 0
-                        and seq_group.lora_int_id in curr_loras):
-                    curr_loras.remove(seq_group.lora_int_id)
-
-                if running_queue:
-                    # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = running_queue.pop()
-                    preempted_mode = self._preempt(victim_seq_group,
-                                                   blocks_to_swap_out)
-                    if preempted_mode == PreemptionMode.RECOMPUTE:
-                        preempted.append(victim_seq_group)
-                    else:
-                        swapped_out.append(victim_seq_group)
-                else:
-                    # No other sequence groups can be preempted.
-                    # Preempt the current sequence group.
-                    preempted_mode = self._preempt(seq_group,
-                                                   blocks_to_swap_out)
-                    if preempted_mode == PreemptionMode.RECOMPUTE:
-                        preempted.append(seq_group)
-                    else:
-                        swapped_out.append(seq_group)
-                    break
+            if self.device_config.device_type == "fpga":
+                while not True: #self._can_append_slots(seq_group):
+                    budget.subtract_num_batched_tokens(seq_group.request_id,
+                                                       num_running_tokens)
+                    num_running_seqs = seq_group.get_max_num_running_seqs()
+                    budget.subtract_num_seqs(seq_group.request_id,
+                                             num_running_seqs)
+
+                    if (curr_loras is not None and seq_group.lora_int_id > 0
+                            and seq_group.lora_int_id in curr_loras):
+                        curr_loras.remove(seq_group.lora_int_id)
+
+                    if running_queue:
+                        # Preempt the lowest-priority sequence groups.
+                        victim_seq_group = running_queue.pop()
+                        preempted_mode = self._preempt(victim_seq_group,
+                                                       blocks_to_swap_out)
+                        if preempted_mode == PreemptionMode.RECOMPUTE:
+                            preempted.append(victim_seq_group)
+                        else:
+                            swapped_out.append(victim_seq_group)
+                    else:
+                        # No other sequence groups can be preempted.
+                        # Preempt the current sequence group.
+                        preempted_mode = self._preempt(seq_group,
+                                                       blocks_to_swap_out)
+                        if preempted_mode == PreemptionMode.RECOMPUTE:
+                            preempted.append(seq_group)
+                        else:
+                            swapped_out.append(seq_group)
+                        break
+                else:
+                    is_prefill = seq_group.is_prefill()
+                    scheduled_seq_group: ScheduledSequenceGroup = \
+                        self._scheduled_seq_group_cache[self.cache_id].get_object()
+                    scheduled_seq_group.seq_group = seq_group
+                    if is_prefill:
+                        scheduled_seq_group.token_chunk_size = num_running_tokens
+                        prefill_seq_groups.append(scheduled_seq_group)
+                        ret.prefill_seq_groups_list.append(seq_group)
+                    else:
+                        scheduled_seq_group.token_chunk_size = 1
+                        decode_seq_groups.append(scheduled_seq_group)
+                        ret.decode_seq_groups_list.append(seq_group)
+
+                    budget.add_num_batched_tokens(seq_group.request_id,
+                                                  num_running_tokens)
+                    # OPTIMIZATION: Note that get_max_num_running_seqs is
+                    # expensive. For the default scheduling chase where
+                    # enable_chunking is False, num_seqs are updated before running
+                    # this method, so we don't have to update it again here.
+                    if enable_chunking:
+                        num_running_seqs = seq_group.get_max_num_running_seqs()
+                        budget.add_num_seqs(seq_group.request_id, num_running_seqs)
+                    if curr_loras is not None and seq_group.lora_int_id > 0:
+                        curr_loras.add(seq_group.lora_int_id)
             else:
-                #self._append_slots(seq_group, blocks_to_copy)
-                is_prefill = seq_group.is_prefill()
-                scheduled_seq_group: ScheduledSequenceGroup = \
-                    self._scheduled_seq_group_cache[self.cache_id].get_object()
-                scheduled_seq_group.seq_group = seq_group
-                if is_prefill:
-                    scheduled_seq_group.token_chunk_size = num_running_tokens
-                    prefill_seq_groups.append(scheduled_seq_group)
-                    ret.prefill_seq_groups_list.append(seq_group)
-                else:
-                    scheduled_seq_group.token_chunk_size = 1
-                    decode_seq_groups.append(scheduled_seq_group)
-                    ret.decode_seq_groups_list.append(seq_group)
-
-                budget.add_num_batched_tokens(seq_group.request_id,
-                                              num_running_tokens)
-                # OPTIMIZATION: Note that get_max_num_running_seqs is
-                # expensive. For the default scheduling chase where
-                # enable_chunking is False, num_seqs are updated before running
-                # this method, so we don't have to update it again here.
-                if enable_chunking:
-                    num_running_seqs = seq_group.get_max_num_running_seqs()
-                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
-                if curr_loras is not None and seq_group.lora_int_id > 0:
-                    curr_loras.add(seq_group.lora_int_id)
+                while not self._can_append_slots(seq_group):
+                    budget.subtract_num_batched_tokens(seq_group.request_id,
+                                                       num_running_tokens)
+                    num_running_seqs = seq_group.get_max_num_running_seqs()
+                    budget.subtract_num_seqs(seq_group.request_id,
+                                             num_running_seqs)
+
+                    if (curr_loras is not None and seq_group.lora_int_id > 0
+                            and seq_group.lora_int_id in curr_loras):
+                        curr_loras.remove(seq_group.lora_int_id)
+
+                    if running_queue:
+                        # Preempt the lowest-priority sequence groups.
+                        victim_seq_group = running_queue.pop()
+                        preempted_mode = self._preempt(victim_seq_group,
+                                                       blocks_to_swap_out)
+                        if preempted_mode == PreemptionMode.RECOMPUTE:
+                            preempted.append(victim_seq_group)
+                        else:
+                            swapped_out.append(victim_seq_group)
+                    else:
+                        # No other sequence groups can be preempted.
+                        # Preempt the current sequence group.
+                        preempted_mode = self._preempt(seq_group,
+                                                       blocks_to_swap_out)
+                        if preempted_mode == PreemptionMode.RECOMPUTE:
+                            preempted.append(seq_group)
+                        else:
+                            swapped_out.append(seq_group)
+                        break
+                else:
+                    self._append_slots(seq_group, blocks_to_copy)
+                    is_prefill = seq_group.is_prefill()
+                    scheduled_seq_group: ScheduledSequenceGroup = \
+                        self._scheduled_seq_group_cache[self.cache_id].get_object()
+                    scheduled_seq_group.seq_group = seq_group
+                    if is_prefill:
+                        scheduled_seq_group.token_chunk_size = num_running_tokens
+                        prefill_seq_groups.append(scheduled_seq_group)
+                        ret.prefill_seq_groups_list.append(seq_group)
+                    else:
+                        scheduled_seq_group.token_chunk_size = 1
+                        decode_seq_groups.append(scheduled_seq_group)
+                        ret.decode_seq_groups_list.append(seq_group)
+
+                    budget.add_num_batched_tokens(seq_group.request_id,
+                                                  num_running_tokens)
+                    # OPTIMIZATION: Note that get_max_num_running_seqs is
+                    # expensive. For the default scheduling chase where
+                    # enable_chunking is False, num_seqs are updated before running
+                    # this method, so we don't have to update it again here.
+                    if enable_chunking:
+                        num_running_seqs = seq_group.get_max_num_running_seqs()
+                        budget.add_num_seqs(seq_group.request_id, num_running_seqs)
+                    if curr_loras is not None and seq_group.lora_int_id > 0:
+                        curr_loras.add(seq_group.lora_int_id)

         self._scheduler_running_outputs_cache[self.next_cache_id].reset()
         self._scheduled_seq_group_cache[self.next_cache_id].reset()
vllm/engine/arg_utils.py

Lines changed: 16 additions & 3 deletions
@@ -14,7 +14,7 @@
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
-from vllm.logger import init_logger
+from vllm.logger import init_logger, print_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser

@@ -149,6 +149,10 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False

+    # NOTE(hyunjun): custom option for hybrid
+    num_gpu_devices: int = 0
+    num_lpu_devices: int = 1
+
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
@@ -741,6 +745,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.disable_async_output_proc,
             help="Disable async output processing. This may result in "
             "lower performance.")
+        parser.add_argument(
+            '--num-gpu-devices',
+            type=int,
+            default=0,
+            help='the number of GPU devices in the hybrid system')
+        parser.add_argument(
+            '--num-lpu-devices',
+            type=int,
+            default=1,
+            help='the number of LPU devices in the hybrid system')
         return parser

     @classmethod
@@ -775,8 +789,7 @@ def create_engine_config(self) -> EngineConfig:
         assert self.cpu_offload_gb >= 0, (
             "CPU offload space must be non-negative"
             f", but got {self.cpu_offload_gb}")
-
-        device_config = DeviceConfig(device=self.device)
+        device_config = DeviceConfig(device=self.device, num_gpu_devices=self.num_gpu_devices, num_lpu_devices=self.num_lpu_devices)
         model_config = ModelConfig(
             model=self.model,
             tokenizer=self.tokenizer,