11#![ cfg( driver_detected) ]
22
3- use llama:: { BlkWeight , Contiguous , LlamaBlkStorage , LlamaStorage , Tensor , WeightLoader } ;
3+ use common:: { Contiguous , Slab } ;
4+ use llama:: { BlkWeight , LlamaBlkStorage , LlamaStorage , Tensor , WeightLoader } ;
5+ use log:: trace;
46use operators:: {
57 all_reduce:: { AllReduce , NonAllReduce } ,
6- cuda:: { memcpy_d2h, CurrentCtx , DevByte , DevMem , Event , HostMem , Stream } ,
8+ cuda:: { memcpy_d2h, AsRaw , CurrentCtx , DevByte , DevMem , Event , HostMem , Stream } ,
79 nvidia_gpu:: Gpu ,
810 random_sample:: nvidia_gpu:: Operator as RandomSampleGpu ,
911 rearrange:: nvidia_gpu:: Operator as Rearrange ,
@@ -15,6 +17,7 @@ use std::{
1517 mem:: replace,
1618 ops:: { Deref , RangeBounds } ,
1719 rc:: Rc ,
20+ time:: Instant ,
1821} ;
1922
2023pub struct Operators < N = Gpu , R = NonAllReduce < Gpu , Rearrange > > ( PhantomData < ( N , R ) > ) ;
@@ -157,11 +160,14 @@ impl<'blk> Weights<'blk> {
157160 ) -> Self {
158161 assert ! ( pool_size > 0 ) ;
159162 let stream = Rc :: new ( ctx. stream ( ) ) ;
163+ let igpu = unsafe { ctx. dev ( ) . as_raw ( ) } ;
164+ let mut slab = Slab :: new ( ) ;
160165 let blks = if pool_size < model. meta . nblk {
161166 let mut blks_host = model. blocks [ 0 ]
162167 . as_ref ( )
163168 . map ( |_| Vec :: with_capacity ( model. meta . nblk ) ) ;
164- for blk in model. blocks . iter ( ) {
169+ for ( iblk, blk) in model. blocks . iter ( ) . enumerate ( ) {
170+ let time = Instant :: now ( ) ;
165171 let blk = blk
166172 . distribute ( & model. meta , range. clone ( ) , count, |len| {
167173 ctx. malloc_host :: < u8 > ( len)
@@ -188,6 +194,7 @@ impl<'blk> Weights<'blk> {
188194 ffn_gate_up
189195 ffn_down
190196 }
197+ trace ! ( "blk{iblk} loaded to gpu{igpu} in {:?}" , time. elapsed( ) )
191198 }
192199 blks_host. map ( |vec| {
193200 let roll_cache = vec
@@ -206,18 +213,26 @@ impl<'blk> Weights<'blk> {
206213 let mut blks_dev = model. blocks [ 0 ]
207214 . as_ref ( )
208215 . map ( |_| Vec :: with_capacity ( model. meta . nblk ) ) ;
209- for blk in & model. blocks {
210- let blk = blk. distribute ( & model. meta , range. clone ( ) , count, |len| {
211- ctx. malloc_host :: < u8 > ( len)
216+ for ( iblk, blk) in model. blocks . iter ( ) . enumerate ( ) {
217+ let blk = blk. distribute ( & model. meta , range. clone ( ) , count, |size| {
218+ slab. take ( & size)
219+ . unwrap_or_else ( || ctx. malloc_host :: < u8 > ( size) )
212220 } ) ;
213221 let loader = loader
214222 . get_or_insert_with ( || blk. as_ref ( ) . map ( |s| H2DLoader :: new ( s. len ( ) , & stream) ) ) ;
215223
216224 macro_rules! load {
217225 ( $( $ident: ident ) + ) => {
218- $( { blks_dev. $ident. push( loader. $ident. load( blk. $ident, & stream) ) ; } ) +
226+ $(
227+ let ( dev, host) = loader. $ident. load( blk. $ident, & stream) ;
228+ if let Some ( host) = host {
229+ slab. put( host. len( ) , host)
230+ }
231+ blks_dev. $ident. push( dev) ;
232+ ) +
219233 } ;
220234 }
235+ let time = Instant :: now ( ) ;
221236 load ! {
222237 attn_norm
223238 attn_qkv
@@ -226,6 +241,7 @@ impl<'blk> Weights<'blk> {
226241 ffn_gate_up
227242 ffn_down
228243 }
244+ trace ! ( "blk{iblk} loaded to gpu{igpu} in {:?}" , time. elapsed( ) )
229245 }
230246 blks_dev. map ( |vec| Cache :: Static ( vec. into_boxed_slice ( ) ) )
231247 } ;
@@ -253,15 +269,25 @@ impl<'ctx> H2DLoader<'ctx> {
253269 }
254270 }
255271
256- fn load ( & mut self , host : Contiguous < HostMem < ' ctx > > , stream : & Stream < ' ctx > ) -> DevMem < ' ctx > {
272+ fn load (
273+ & mut self ,
274+ host : Contiguous < HostMem < ' ctx > > ,
275+ stream : & Stream < ' ctx > ,
276+ ) -> ( DevMem < ' ctx > , Option < HostMem < ' ctx > > ) {
257277 self . event . synchronize ( ) ;
258- match host {
259- Contiguous :: Borrowed ( host) => self . host . copy_from_slice ( host) ,
260- Contiguous :: Owned ( host) => self . host = host,
278+ let cache = match host {
279+ Contiguous :: Borrowed ( host) => {
280+ self . host . copy_from_slice ( host) ;
281+ None
282+ }
283+ Contiguous :: Owned ( host) => Some ( replace ( & mut self . host , host) ) ,
261284 } ;
262285 stream. memcpy_h2d ( & mut self . dev , & self . host ) ;
263286 self . event = stream. record ( ) ;
264- replace ( & mut self . dev , stream. malloc :: < u8 > ( self . host . len ( ) ) )
287+ (
288+ replace ( & mut self . dev , stream. malloc :: < u8 > ( self . host . len ( ) ) ) ,
289+ cache,
290+ )
265291 }
266292}
267293
0 commit comments