22use crate :: { exec:: KVCache , op:: random_sample:: SampleArgs } ;
33use log:: warn;
44use std:: {
5+ cmp:: min,
56 collections:: BTreeMap ,
67 iter:: repeat_n,
78 mem:: take,
@@ -12,12 +13,12 @@ use std::{
1213} ;
1314use tokeneer:: utok;
1415
15- #[ derive( Default ) ]
1616pub ( super ) struct EngineManager {
1717 sess : BTreeMap < SessionId , SessionStub > ,
1818 pre_output : BTreeMap < SessionId , usize > ,
1919 // 每次prefill的最大长度
20- chunked_prefill_len : Option < usize > ,
20+ chunked_prefill_max_len : Option < usize > ,
21+ max_toks : usize ,
2122}
2223
2324#[ derive( Default ) ]
@@ -41,10 +42,12 @@ pub enum CommandError {
4142type E = CommandError ;
4243
4344impl EngineManager {
44- pub fn new ( chunked_prefill_len : Option < usize > ) -> Self {
45+ pub fn new ( chunked_prefill_len : Option < usize > , max_toks : usize ) -> Self {
4546 Self {
46- chunked_prefill_len,
47- ..Default :: default ( )
47+ sess : Default :: default ( ) ,
48+ pre_output : Default :: default ( ) ,
49+ chunked_prefill_max_len : chunked_prefill_len,
50+ max_toks,
4851 }
4952 }
5053 /// 接收命令
@@ -82,53 +85,62 @@ impl EngineManager {
8285 let mut out_idx = 0 ;
8386
8487 let pre_output = take ( & mut self . pre_output ) ;
85- for ( id, mut stub) in take ( & mut self . sess ) {
88+
89+ let mut write_back_sessions = BTreeMap :: < SessionId , SessionStub > :: new ( ) ;
90+
91+ while let Some ( ( id, mut stub) ) = self . sess . pop_first ( ) {
8692 let max = stub. session . cache . len ;
8793 let pos = stub. session . cache . pos ;
88- let seq = stub. state . seq ;
89- let out = stub. state . out ;
90- let end = pos + seq;
91- assert_eq ! ( out, 1 , "TODO: ??? " ) ;
94+ let mut seq = stub. state . seq ;
95+ let mut out = stub. state . out ;
96+ let mut end = pos + seq;
97+ assert_eq ! ( out, 1 , "TODO: 投机采样 " ) ;
9298 //验证缓存是否溢出
9399 if end > max {
94100 warn ! ( "cache overflow {end} > {max}" ) ;
95101 // 缓存溢出,不再推理
96102 ans. overflow . push ( stub. session ) ;
97103 continue ;
98104 }
99- //chunked prefill
100- // 采用 state.seq 用于计算剩余需要prefill的长度
105+
106+ // 用于限制每次tokens总数
107+ let remain_tok_num = self . max_toks - ans. tokens . len ( ) ;
108+ assert ! ( remain_tok_num > 0 ) ;
109+
101110 if let Some ( prompt) = & stub. prompt {
102- if let Some ( chunked_prefill_len) = self . chunked_prefill_len {
103- if stub. state . seq > chunked_prefill_len {
104- // 根据chunked_prefill_len重新计算seq和end
105- let seq = chunked_prefill_len;
106- let end = pos + seq;
107- ans. sample . push ( stub. session . sample_args ) ;
108- ans. output . push ( ( id, 0 ) ) ;
109- ans. reqs . push ( Req {
110- kv_cache : stub. session . cache . parts . clone ( ) ,
111- pos,
112- seq,
113- } ) ;
114- ans. tokens . extend (
115- prompt
116- . iter ( )
117- . skip ( prompt. len ( ) - stub. state . seq )
118- . take ( chunked_prefill_len) ,
119- ) ;
120-
121- //更新stub信息
122- stub. session . cache . pos = end;
123- stub. state . seq -= chunked_prefill_len;
124-
125- //回填
126- assert ! ( self . sess. insert( id, stub) . is_none( ) ) ;
127-
128- //提前结束
129- continue ;
111+ seq = self
112+ . chunked_prefill_max_len
113+ . map_or ( min ( remain_tok_num, seq) , |chunked_prefill_max_len| {
114+ remain_tok_num. min ( seq) . min ( chunked_prefill_max_len)
115+ } ) ;
116+
117+ if seq < stub. state . seq {
118+ // chunked prefill
119+ out = 0 ;
120+ end = pos + seq;
121+
122+ ans. tokens
123+ . extend ( prompt. iter ( ) . skip ( prompt. len ( ) - stub. state . seq ) . take ( seq) ) ;
124+
125+ //更新stub信息
126+ stub. state . seq -= seq;
127+ } else {
128+ // 正常prefill
129+ if seq != prompt. len ( ) {
130+ log:: debug!( "{:?} chunked prefil finished" , id) ;
130131 }
132+ ans. tokens . extend ( prompt[ prompt. len ( ) - seq..] . to_owned ( ) ) ;
133+
134+ stub. state . seq = 1 ;
135+ stub. prompt = None ;
131136 }
137+ } else {
138+ // decode
139+ assert_eq ! ( seq, 1 ) ;
140+ // fast embd
141+ ans. fast_map
142+ . push ( ( pre_output[ & id] as _ , ans. tokens . len ( ) as _ ) ) ;
143+ ans. tokens . push ( 0 )
132144 }
133145
134146 // 尝试填充缓存
@@ -141,35 +153,28 @@ impl EngineManager {
141153 pos,
142154 seq,
143155 } ) ;
144- if let Some ( prompt) = stub. prompt . take ( ) {
145- // prefill
146- if seq != prompt. len ( ) {
147- log:: debug!( "{:?} chunked prefil finished" , id) ;
148- }
149- ans. tokens . extend ( prompt[ prompt. len ( ) - seq..] . to_owned ( ) ) ;
150-
151- stub. state . seq = 1
152- } else {
153- // decode
154- assert_eq ! ( seq, 1 ) ;
155- // fast embd
156- ans. fast_map
157- . push ( ( pre_output[ & id] as _ , ans. tokens . len ( ) as _ ) ) ;
158- ans. tokens . push ( 0 )
159- }
160156
161157 //输出处理
162- stub. state . remain_steps -= 1 ;
158+ //不会溢出 因为 out <= 1
159+ stub. state . remain_steps -= out;
163160 if stub. state . remain_steps == 0 {
164161 // 生成结束
165162 ans. finished . push ( stub. session )
166163 } else {
167164 // 回填
168- assert ! ( self . sess. insert( id, stub) . is_none( ) ) ;
169- assert ! ( self . pre_output. insert( id, out_idx) . is_none( ) ) ;
165+ assert ! ( write_back_sessions. insert( id, stub) . is_none( ) ) ;
166+ if out != 0 {
167+ assert ! ( self . pre_output. insert( id, out_idx) . is_none( ) ) ;
168+ }
169+ }
170+ out_idx += out;
171+
172+ // 如果剩余tokens总数等于0,则退出循环
173+ if self . max_toks == ans. tokens . len ( ) {
174+ break ;
170175 }
171- out_idx += out
172176 }
177+ self . sess . append ( & mut write_back_sessions) ;
173178 ans
174179 }
175180
0 commit comments