
Commit e6d28c8

refactor: dispatcher as forwarding decision maker for llm (#305)
LLM serving is stateful: when a request is served in a distributed manner, it must be routed to the same set of workers for efficient decoding. In particular, when a token output needs to go back to the first layer, it has to return to the worker that holds the request's KV cache; otherwise decoding is incorrect or breaks outright. However, when the serving pipeline is constructed as a mesh, the current implementation does not guarantee correct forwarding. To allow correct forwarding, we make the server (dispatcher) act as the forwarding decision maker for LLMs. Since there is only one dispatcher, the workers of the last stage can deterministically forward a generated token back to the dispatcher, and the dispatcher then determines whether the token needs to be sent back to the first stage or not.
1 parent 68deaea commit e6d28c8
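
As a reading aid, here is a minimal sketch of that forwarding decision. This is not the project's code: names such as Dispatcher, first_stage_workers, and route are illustrative assumptions; the actual logic added by this commit lives in infscale/execution/router.py below.

# Hypothetical sketch of the dispatcher's forwarding decision (not the repo's API).
from dataclasses import dataclass, field


@dataclass
class Dispatcher:
    # seqno -> first-stage worker chosen for this request; decode steps must
    # return to the same worker because it holds the request's KV cache.
    first_stage_workers: dict[int, str] = field(default_factory=dict)

    def route(self, seqno: int, output: dict) -> str:
        """Decide where a last-stage output goes next."""
        if "tokens" in output:
            # decoding finished: nothing to forward; the result goes back
            # to the client from the dispatcher.
            return "client"
        # decoding continues: send the output back to the first stage,
        # always to the replica that already holds the KV cache.
        return self.first_stage_workers[seqno]


# usage: request 7 was originally dispatched to first-stage worker "0-0"
d = Dispatcher(first_stage_workers={7: "0-0"})
print(d.route(7, {"hidden": ...}))   # -> "0-0" (keep decoding)
print(d.route(7, {"tokens": [42]}))  # -> "client" (decoding done)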

File tree

12 files changed (+258, −47 lines)


examples/llama3/auto/linear-no-recover.yaml

Lines changed: 0 additions & 2 deletions
@@ -19,8 +19,6 @@ flow_graph:
   0-0:
     - name: w1
       peers: [s-0]
-    - name: w2
-      peers: [2-0]
   1-0:
     - name: w3
       peers: [0-0]

examples/llama3/auto/linear.yaml

Lines changed: 0 additions & 2 deletions
@@ -18,8 +18,6 @@ flow_graph:
   0-0:
     - name: w1
       peers: [s-0]
-    - name: w2
-      peers: [2-0]
   1-0:
     - name: w3
       peers: [0-0]

examples/llama3/auto/linear_p3.8xlarge.yaml

Lines changed: 0 additions & 2 deletions
@@ -18,8 +18,6 @@ flow_graph:
   0-0:
     - name: w1
       peers: [s-0]
-    - name: w2
-      peers: [2-0]
   1-0:
     - name: w3
       peers: [0-0]

examples/llama3/static/linear-no-recover.yaml

Lines changed: 0 additions & 4 deletions
@@ -24,10 +24,6 @@ flow_graph:
       peers: [s-0]
       addr: 10.20.1.50
       backend: gloo
-    - name: w2
-      peers: [2-0]
-      addr: 10.20.1.50
-      backend: gloo
   1-0:
     - name: w3
       peers: [0-0]

examples/llama3/static/linear.yaml

Lines changed: 0 additions & 4 deletions
@@ -23,10 +23,6 @@ flow_graph:
       peers: [s-0]
       addr: 10.20.1.50
       backend: gloo
-    - name: w2
-      peers: [2-0]
-      addr: 10.20.1.50
-      backend: gloo
   1-0:
     - name: w3
       peers: [0-0]

examples/llama3/static/linear_p3.8xlarge.yaml

Lines changed: 4 additions & 8 deletions
@@ -16,26 +16,22 @@ flow_graph:
   s-0:
     - name: w0
       peers: [2-0]
-      addr: 10.20.1.50
+      addr: 10.20.1.72
       backend: nccl
   0-0:
     - name: w1
       peers: [s-0]
-      addr: 10.20.1.50
-      backend: nccl
-    - name: w2
-      peers: [2-0]
-      addr: 10.20.1.50
+      addr: 10.20.1.72
       backend: nccl
   1-0:
     - name: w3
       peers: [0-0]
-      addr: 10.20.1.50
+      addr: 10.20.1.72
       backend: nccl
   2-0:
     - name: w4
       peers: [1-0]
-      addr: 10.20.1.50
+      addr: 10.20.1.72
       backend: nccl
 
 dataset: # huggingface dataset

examples/llama3/static/mesh.yaml

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@ (new file)
---
name: llama3 linear example
model: meta-llama/Meta-Llama-3.1-8B
nfaults: 1
# the following entries are for local development only
# after development is done, they need to be revised accordingly
# to automate building flow_graph, etc.
micro_batch_size: 2
fwd_policy: rr
job_id: "job5"
# maximum number of requests in flight at any given point in time
max_inflight: 4

# Note: IP addresses should be agents'
flow_graph:
  s-0:
    - name: w4
      peers: [2-0]
      addr: 10.20.1.72
      backend: nccl
    - name: w9
      peers: [5-0]
      addr: 10.20.1.72
      backend: nccl
  0-0:
    - name: w0
      peers: [s-0]
      addr: 10.20.1.72
      backend: nccl
  1-0:
    - name: w1
      peers: [0-0]
      addr: 10.20.1.72
      backend: nccl
    - name: w10
      peers: [3-0]
      addr: 10.20.1.72
      backend: nccl
  2-0:
    - name: w2
      peers: [1-0]
      addr: 10.20.1.72
      backend: nccl
    - name: w11
      peers: [4-0]
      addr: 10.20.1.72
      backend: nccl
  3-0:
    - name: w5
      peers: [s-0]
      addr: 10.20.1.50
      backend: nccl
  4-0:
    - name: w6
      peers: [3-0]
      addr: 10.20.1.50
      backend: nccl
    - name: w12
      peers: [0-0]
      addr: 10.20.1.50
      backend: nccl
  5-0:
    - name: w7
      peers: [4-0]
      addr: 10.20.1.50
      backend: nccl
    - name: w13
      peers: [1-0]
      addr: 10.20.1.50
      backend: nccl
dataset: # huggingface dataset
  path: fka/awesome-chatgpt-prompts
  name: ""
  split: train

workers:
  - id: s-0
    device: cuda:0
    is_server: True
    stage:
      start: -1
      end: -1
  - id: 0-0
    device: cuda:1
    stage:
      start: 0
      end: 10
  - id: 1-0
    device: cuda:2
    stage:
      start: 11
      end: 23
  - id: 2-0
    device: cuda:3
    stage:
      start: 24
      end: 34
  - id: 3-0
    device: cuda:1
    stage:
      start: 0
      end: 10
  - id: 4-0
    device: cuda:2
    stage:
      start: 11
      end: 23
  - id: 5-0
    device: cuda:3
    stage:
      start: 24
      end: 34
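
As a reading aid for the topology above, a small sketch (assuming only standard PyYAML; the printed summary format is illustrative, not part of the repo) that lists each worker's layer range and its outgoing connections from flow_graph:

# Illustrative helper: summarize the mesh topology defined in the YAML above.
import yaml

with open("examples/llama3/static/mesh.yaml") as f:
    cfg = yaml.safe_load(f)

# layer range served by each worker (the dispatcher s-0 uses -1/-1)
stages = {w["id"]: (w["stage"]["start"], w["stage"]["end"]) for w in cfg["workers"]}

for worker_id, conns in cfg["flow_graph"].items():
    start, end = stages[worker_id]
    for conn in conns:
        # each entry is a communication "world" pointing at one or more peers
        print(f"{worker_id} (layers {start}..{end}) --{conn['name']}--> {conn['peers']}")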
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@ (new file)
---
name: llama3 linear example
model: meta-llama/Meta-Llama-3.1-8B
nfaults: 1
# the following entries are for local development only
# after development is done, they need to be revised accordingly
# to automate building flow_graph, etc.
micro_batch_size: 2
fwd_policy: rr
job_id: "job5"
# maximum number of requests in flight at any given point in time
max_inflight: 4

# Note: IP addresses should be agents'
flow_graph:
  s-0:
    - name: w4
      peers: [2-0]
      addr: 10.20.1.72
      backend: nccl
    - name: w9
      peers: [5-0]
      addr: 10.20.1.72
      backend: nccl
  0-0:
    - name: w0
      peers: [s-0]
      addr: 10.20.1.72
      backend: nccl
  1-0:
    - name: w1
      peers: [0-0]
      addr: 10.20.1.72
      backend: nccl
  2-0:
    - name: w2
      peers: [1-0]
      addr: 10.20.1.72
      backend: nccl
  3-0:
    - name: w5
      peers: [s-0]
      addr: 10.20.1.50
      backend: nccl
  4-0:
    - name: w6
      peers: [3-0]
      addr: 10.20.1.50
      backend: nccl
  5-0:
    - name: w7
      peers: [4-0]
      addr: 10.20.1.50
      backend: nccl

dataset: # huggingface dataset
  path: fka/awesome-chatgpt-prompts
  name: ""
  split: train

workers:
  - id: s-0
    device: cuda:0
    is_server: True
    stage:
      start: -1
      end: -1
  - id: 0-0
    device: cuda:1
    stage:
      start: 0
      end: 10
  - id: 1-0
    device: cuda:2
    stage:
      start: 11
      end: 23
  - id: 2-0
    device: cuda:3
    stage:
      start: 24
      end: 34
  - id: 3-0
    device: cuda:1
    stage:
      start: 0
      end: 10
  - id: 4-0
    device: cuda:2
    stage:
      start: 11
      end: 23
  - id: 5-0
    device: cuda:3
    stage:
      start: 24
      end: 34

infscale/controller/cfggen.py

Lines changed: 18 additions & 13 deletions
@@ -841,19 +841,24 @@ def _build_flow_graph(self):
                 )
                 world_id += 1
 
-            # For LLM(e.g., llama), add feedback connections from last stage to first stage
-            if self._is_auto_regressive and i == 0 and len(stage_ids) > 1:
-                last_stage_workers = stages[stage_ids[-1]]
-                for last_worker in last_stage_workers:
-                    connections.append(
-                        {
-                            "name": f"w{world_id}",
-                            "peers": FlowList([last_worker["id"]]),
-                            "addr": worker_addr,
-                            "backend": "nccl",
-                        }
-                    )
-                    world_id += 1
+            # WE DON'T NEED A FEEDBACK CONNECTION ANY MORE SINCE THE
+            # DISPATCHER HANDLES THAT.
+            # WE KEEP THE COMMENTED-OUT CODE JUST IN CASE WE HAVE TO
+            # REVERT IT IN THE FUTURE
+            #
+            # # For LLM(e.g., llama), add feedback connections from last stage to first stage
+            # if self._is_auto_regressive and i == 0 and len(stage_ids) > 1:
+            #     last_stage_workers = stages[stage_ids[-1]]
+            #     for last_worker in last_stage_workers:
+            #         connections.append(
+            #             {
+            #                 "name": f"w{world_id}",
+            #                 "peers": FlowList([last_worker["id"]]),
+            #                 "addr": worker_addr,
+            #                 "backend": "nccl",
+            #             }
+            #         )
+            #         world_id += 1
 
         flow_graph[worker_id] = connections
 
infscale/execution/router.py

Lines changed: 22 additions & 2 deletions
@@ -66,6 +66,7 @@ def __init__(self, world_manager: WorldManager, mc: MetricsCollector):
         _ = asyncio.create_task(self._recv_arbiter())
 
         self._fwder: Forwarder = None
+        self._is_server = False
 
     @property
     def rx_q(self) -> asyncio.Queue:
@@ -115,6 +116,7 @@ async def configure(
         worlds_to_remove: list[WorldInfo] = [],
     ) -> None:
         """(Re)configure router."""
+        self._is_server = spec.is_server
         self.device = device
 
         if self._fwder is None:
@@ -314,8 +316,26 @@ async def _recv_arbiter(self) -> None:
         while True:
             try:
                 tensor, seqno = await self.__rx_q.get()
-                # TODO: introduce a prioritization policy
-                await self._rx_q.put((tensor, seqno))
+
+                if (
+                    self._is_server
+                    and self._fwder.is_sticky()
+                    and "tokens" not in tensor
+                ):
+                    # if router is configured for server (i.e., dispatcher),
+                    # we need to check the following:
+                    #
+                    # if the model is llm (i.e., is_sticky()), we need to check
+                    # if decoding is done (if "tokens" are in tensor or not):
+                    # if not, we have to return the tensor to the stage with
+                    # layer 0 to continue decoding
+                    # For that, we put the tensor back to _tx_q so that
+                    # send_arbiter() takes care of sending the tensor to
+                    # a correct stage (i.e., worker) whose start layer is 0.
+                    await self._tx_q.put((seqno, tensor, 0))
+                else:
+                    # TODO: introduce a prioritization policy
+                    await self._rx_q.put((tensor, seqno))
             except Exception as e:
                 # this is very likely to be a no-op due to the simple
                 # get and put operations we do on the asyncio queues.
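
The comment above refers to send_arbiter() picking a worker whose start layer is 0. A rough sketch of what such a selection could look like, under the assumption (not shown in this diff) that the sender keeps per-request stickiness so decode steps return to the replica holding the KV cache; Worker, make_picker, and the round-robin choice are illustrative only, not the repo's send_arbiter():

# Illustrative only: choose a first-stage replica, sticky per request (seqno).
from dataclasses import dataclass
from itertools import cycle


@dataclass
class Worker:
    id: str
    stage_start: int


def make_picker(workers: list[Worker]):
    first_stage = [w for w in workers if w.stage_start == 0]
    rr = cycle(first_stage)          # spread new requests over first-stage replicas
    sticky: dict[int, Worker] = {}   # seqno -> chosen first-stage replica

    def pick(seqno: int) -> Worker:
        # first time we see this request, pick a replica; afterwards,
        # always return the same replica (it holds the KV cache).
        if seqno not in sticky:
            sticky[seqno] = next(rr)
        return sticky[seqno]

    return pick


pick = make_picker([Worker("0-0", 0), Worker("3-0", 0), Worker("1-0", 11)])
print(pick(7).id, pick(8).id, pick(7).id)  # -> 0-0 3-0 0-0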
