@@ -16,10 +16,6 @@ def __init__(self, backend: ModeBackend):
1616 self .left_decode_num = self .decode_max_step
1717
1818 self .step_count = 0
19-
20- # dp prefill 配平调度的延迟参数。
21- self .dp_prefill_wait_step = 0
22- self .dp_prefill_wait_max_step = get_env_start_args ().dp_prefill_wait_step
2319 return
2420
2521 def select_run_way (
@@ -71,48 +67,23 @@ def _normal_way(
7167 prefill_reqs : List [InferReq ],
7268 decode_reqs : List [InferReq ],
7369 ):
74- """
75- _normal_way 接口用于控制 DP 模式下进行chuncked prefill时,需要考虑各个DP的真实运行请求数量:
76- 考虑 8 个 dp 的场景,如果每个 dp 执行 prefill 的请求的数量分别为: [1, 1, 0, 0, 0, 0, 0, 0], 则在运行
77- 的过程中,请求数量为0的dp会pad一个fake req来参与计算,但是这会导致这些dp因为一些通信同步的原因,造成大量
78- 算力浪费,实际有效率很低。
79- 解决方法:
80- 在判断是否可以进行 prefill 的时候,需要先考虑所有dp的请求数量是否均衡,浪费率是否在可以接受的范围,如果无法
81- 接受这么高的浪费率,则可以延迟 prefill 的执行时机,直到所有dp的浪费率较低时再进行prefill, 不过延迟执行的极限
82- 等待时间,受到 dp_prefill_wait_step 参数的控制。
83- """
84- use_ratio = np .count_nonzero (dp_prefill_req_nums ) / dp_prefill_req_nums .shape [0 ]
70+ # use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
8571 max_decode_num = np .max (dp_decode_req_nums )
8672 max_prefill_num = np .max (dp_prefill_req_nums )
8773
8874 if self .left_decode_num > 0 and max_decode_num > 0 :
8975 self .left_decode_num -= 1
9076 return RunWay .DECODE
9177
92- if use_ratio < 0.6 :
93- if max_prefill_num > 0 :
94- self .dp_prefill_wait_step += 1
95- if self .dp_prefill_wait_step > self .dp_prefill_wait_max_step :
96- # prefill 一次允许进行几次 decode 操作。
97- self .left_decode_num = self .decode_max_step
98- self .dp_prefill_wait_step = max (0 , (self .dp_prefill_wait_step - self .decode_max_step ))
99- return RunWay .PREFILL
100-
78+ if max_prefill_num > 0 :
79+ # prefill 一次允许进行几次 decode 操作。
80+ self .left_decode_num = self .decode_max_step
81+ return RunWay .PREFILL
82+ else :
10183 if max_decode_num > 0 :
10284 return RunWay .DECODE
10385 else :
10486 return RunWay .PASS
105- else :
106- if max_prefill_num > 0 :
107- self .dp_prefill_wait_step = 0
108- # prefill 一次允许进行几次 decode 操作。
109- self .left_decode_num = self .decode_max_step
110- return RunWay .PREFILL
111- else :
112- if max_decode_num > 0 :
113- return RunWay .DECODE
114- else :
115- return RunWay .PASS
11687
11788 def try_recover_paused_reqs (self ) -> bool :
11889 return self .step_count % 100 == 0
0 commit comments