@@ -54,17 +54,19 @@ pub enum BlkWeight {
5454
5555pub trait WeightLoader {
5656 type Hardware : Hardware ;
57- type Memory : Deref < Target = [ ByteOf < Self :: Hardware > ] > ;
57+ type Memory < ' s > : Deref < Target = [ ByteOf < Self :: Hardware > ] > + ' s
58+ where
59+ Self : ' s ;
5860
5961 fn load_blk (
6062 & self ,
6163 which : BlkWeight ,
6264 iblk : usize ,
6365 queue : & QueueOf < Self :: Hardware > ,
64- ) -> Self :: Memory ;
66+ ) -> Self :: Memory < ' _ > ;
6567
66- fn output_norm ( & self , queue : & QueueOf < Self :: Hardware > ) -> Self :: Memory ;
67- fn output ( & self , queue : & QueueOf < Self :: Hardware > ) -> Self :: Memory ;
68+ fn output_norm ( & self , queue : & QueueOf < Self :: Hardware > ) -> Self :: Memory < ' _ > ;
69+ fn output ( & self , queue : & QueueOf < Self :: Hardware > ) -> Self :: Memory < ' _ > ;
6870}
6971
7072pub struct LlamaWorker < Ops : Operators , W > {
@@ -544,60 +546,60 @@ impl LlamaMeta {
544546
545547impl < W : WeightLoader > WeightDecorator < W > {
546548 #[ inline]
547- pub fn attn_norm ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
549+ pub fn attn_norm ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
548550 combine (
549551 & self . attn_norm ,
550552 self . weights . load_blk ( BlkWeight :: AttnNorm , iblk, queue) ,
551553 )
552554 }
553555
554556 #[ inline]
555- pub fn attn_qkv ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
557+ pub fn attn_qkv ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
556558 combine (
557559 & self . attn_qkv ,
558560 self . weights . load_blk ( BlkWeight :: AttnQKV , iblk, queue) ,
559561 )
560562 }
561563
562564 #[ inline]
563- pub fn attn_o ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
565+ pub fn attn_o ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
564566 combine (
565567 & self . attn_o ,
566568 self . weights . load_blk ( BlkWeight :: AttnO , iblk, queue) ,
567569 )
568570 }
569571
570572 #[ inline]
571- pub fn ffn_norm ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
573+ pub fn ffn_norm ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
572574 combine (
573575 & self . ffn_norm ,
574576 self . weights . load_blk ( BlkWeight :: FfnNorm , iblk, queue) ,
575577 )
576578 }
577579
578580 #[ inline]
579- pub fn ffn_gate_up ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
581+ pub fn ffn_gate_up ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
580582 combine (
581583 & self . ffn_gate_up ,
582584 self . weights . load_blk ( BlkWeight :: FfnGateUp , iblk, queue) ,
583585 )
584586 }
585587
586588 #[ inline]
587- pub fn ffn_down ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
589+ pub fn ffn_down ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
588590 combine (
589591 & self . ffn_down ,
590592 self . weights . load_blk ( BlkWeight :: FfnDown , iblk, queue) ,
591593 )
592594 }
593595
594596 #[ inline]
595- pub fn output_norm ( & self , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
597+ pub fn output_norm ( & self , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
596598 combine ( & self . output_norm , self . weights . output_norm ( queue) )
597599 }
598600
599601 #[ inline]
600- pub fn output ( & self , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory > {
602+ pub fn output ( & self , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
601603 combine ( & self . output , self . weights . output ( queue) )
602604 }
603605}
0 commit comments