@@ -23,7 +23,8 @@ use std::{
2323 num:: NonZeroUsize ,
2424 ops:: Deref ,
2525 sync:: {
26- Arc , Barrier , Mutex , RwLock ,
26+ Arc , Barrier , Mutex , OnceLock , RwLock ,
27+ atomic:: AtomicUsize ,
2728 mpsc:: { Receiver , Sender } ,
2829 } ,
2930} ;
@@ -76,23 +77,41 @@ const NTOKS: [usize; 7] = [1, 8, 32, 64, 128, 256, 1024];
7677const CHUNKED_PREFILL_LEN : Option < usize > = Some ( 256 ) ;
7778const MAX_TOKS : usize = 1024 ;
7879
/// Shared, thread-safe tracker for model-weight loading progress.
///
/// One `Progress` is (optionally) attached to each worker/GPU in `engine`
/// via `Arc<Progress>`, so a UI or caller thread can observe load progress
/// while loader threads update it without locking.
#[derive(Default)]
pub struct Progress {
    /// Total amount of weight data expected; written exactly once when the
    /// total becomes known (`OnceLock`), read thereafter.
    /// NOTE(review): units (bytes vs. tensor count) are not visible from this
    /// diff — confirm against the loader that sets it.
    pub(crate) weight_size: OnceLock<usize>,
    /// Amount loaded so far, incremented atomically by the loading thread(s);
    /// presumably in the same units as `weight_size` — TODO confirm.
    pub(crate) weight_loaded: AtomicUsize,
}
85+
7986pub ( crate ) fn engine (
8087 llama : LLaMA < Tensor < & [ u8 ] , 2 > > ,
81- gpus : & [ c_int ] ,
88+ workers : & [ ( c_int , Option < Arc < Progress > > ) ] ,
8289 commands : Receiver < Command > ,
8390 outputs : Sender < Output > ,
8491 use_cuda_graph : bool ,
8592) {
86- if let & [ dev] = gpus {
87- return mono ( llama, Device :: new ( dev) , commands, outputs, use_cuda_graph) ;
93+ if let & [ ( gpu, progress) ] = & workers {
94+ return mono (
95+ llama,
96+ Device :: new ( * gpu) ,
97+ progress. clone ( ) ,
98+ commands,
99+ outputs,
100+ use_cuda_graph,
101+ ) ;
88102 }
89103
90104 #[ cfg( not( nccl) ) ]
91105 unreachable ! ( ) ;
92106
93107 #[ cfg( nccl) ]
94108 {
95- let mut comms = CommunicatorGroup :: new ( gpus) . into_vec ( ) . into_iter ( ) ;
109+ use std:: collections:: HashMap ;
110+
111+ let devlist = workers. iter ( ) . map ( |( gpu, _) | * gpu) . collect :: < Vec < _ > > ( ) ;
112+ let mut workers = workers. iter ( ) . cloned ( ) . collect :: < HashMap < _ , _ > > ( ) ;
113+
114+ let mut comms = CommunicatorGroup :: new ( & devlist) . into_vec ( ) . into_iter ( ) ;
96115 let first = comms. next ( ) . unwrap ( ) ;
97116
98117 let mut llama = llama;
@@ -102,26 +121,28 @@ pub(crate) fn engine(
102121 dist : Distribution {
103122 start : 0 ,
104123 len : 1 ,
105- total : gpus . len ( ) ,
124+ total : devlist . len ( ) ,
106125 } ,
126+ progress : workers. remove ( & first. device ( ) . index ( ) ) . unwrap ( ) ,
107127 config : ModelGroupConfig {
108128 static_model_keys : NTOKS ,
109129 dyn_cache_size : 1 ,
110130 use_cuda_graph,
111131 } ,
112132 max_toks : MAX_TOKS ,
113- barrier : Some ( Arc :: new ( Barrier :: new ( gpus . len ( ) ) ) ) ,
133+ barrier : Some ( Arc :: new ( Barrier :: new ( devlist . len ( ) ) ) ) ,
114134 task_box : Default :: default ( ) ,
115135 chunked_prefill_len : CHUNKED_PREFILL_LEN ,
116136 } ;
117-
118137 std:: thread:: scope ( |s| {
119138 let _threads = comms
120139 . map ( |comm| {
121- let dist = Distribution :: new ( comm. rank ( ) , 1 , gpus. len ( ) ) ;
140+ let dev = comm. device ( ) ;
141+ let dist = Distribution :: new ( comm. rank ( ) , 1 , devlist. len ( ) ) ;
122142 let worker = Worker {
123- dev : comm . device ( ) ,
143+ dev,
124144 dist,
145+ progress : workers. remove ( & dev. index ( ) ) . unwrap ( ) ,
125146 ..worker. clone ( )
126147 } ;
127148 let llama = llama. clone ( ) ;
@@ -139,6 +160,7 @@ pub(crate) fn engine(
139160fn mono (
140161 mut llama : LLaMA < Tensor < & [ u8 ] , 2 > > ,
141162 dev : Device ,
163+ progress : Option < Arc < Progress > > ,
142164 commands : Receiver < Command > ,
143165 outputs : Sender < Output > ,
144166 use_cuda_graph : bool ,
@@ -151,6 +173,7 @@ fn mono(
151173 len : 1 ,
152174 total : 1 ,
153175 } ,
176+ progress,
154177 config : ModelGroupConfig {
155178 static_model_keys : NTOKS ,
156179 dyn_cache_size : 1 ,
@@ -170,6 +193,7 @@ fn mono(
170193struct Worker < T > {
171194 dev : Device ,
172195 dist : Distribution ,
196+ progress : Option < Arc < Progress > > ,
173197 config : ModelGroupConfig < T > ,
174198 max_toks : usize ,
175199 barrier : Option < Arc < Barrier > > ,
@@ -197,6 +221,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
197221 let Self {
198222 dev,
199223 dist,
224+ progress,
200225 config,
201226 max_toks,
202227 barrier,
@@ -210,8 +235,15 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
210235 gpu. apply ( |ctx| {
211236 let mut manager = EngineManager :: new ( chunked_prefill_len, max_toks) ;
212237 let mut handle = handle ( ctx) ;
213- let mut models =
214- ModelGroup :: new ( llama, dist, config, attn, & mut handle, barrier. as_deref ( ) ) ;
238+ let mut models = ModelGroup :: new (
239+ llama,
240+ dist,
241+ progress,
242+ config,
243+ attn,
244+ & mut handle,
245+ barrier. as_deref ( ) ,
246+ ) ;
215247
216248 let mut output_head = OutputHead :: new ( output_head, ctx) ;
217249
@@ -332,6 +364,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
332364 let Self {
333365 dev,
334366 dist,
367+ progress,
335368 config,
336369 max_toks : _max_toks,
337370 barrier,
@@ -345,8 +378,15 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
345378 let barrier = barrier. unwrap ( ) ;
346379 gpu. apply ( |ctx| {
347380 let mut handle = Handle :: with_comm ( ctx, comm) ;
348- let mut models =
349- ModelGroup :: new ( llama, dist, config, attn, & mut handle, Some ( & barrier) ) ;
381+ let mut models = ModelGroup :: new (
382+ llama,
383+ dist,
384+ progress,
385+ config,
386+ attn,
387+ & mut handle,
388+ Some ( & barrier) ,
389+ ) ;
350390
351391 let stream = ctx. stream ( ) ;
352392 loop {
0 commit comments