1- use super :: { args:: Args , LlamaMeta } ;
1+ use super :: { args:: Args , LlamaBlkWeight , LlamaMeta } ;
22use gguf:: ggml_quants:: {
33 digit_layout:: { types as ty, DigitLayout } ,
44 f16,
@@ -53,18 +53,6 @@ pub trait Operators {
5353 }
5454}
5555
56- #[ derive( Clone , Copy , PartialEq , Eq , Debug ) ]
57- pub enum BlkWeight {
58- AttnNorm ,
59- AttnQKV ,
60- AttnQKVBias ,
61- AttnO ,
62- FfnNorm ,
63- FfnGateInp ,
64- FfnGateUp ,
65- FfnDown ,
66- }
67-
6856pub trait WeightLoader {
6957 type Hardware : Hardware ;
7058 type Weight < ' s > : Deref < Target = [ ByteOf < Self :: Hardware > ] > + ' s
@@ -73,14 +61,14 @@ pub trait WeightLoader {
7361
7462 fn load_blk < ' a > (
7563 & ' a self ,
76- which : BlkWeight ,
64+ which : LlamaBlkWeight ,
7765 iblk : usize ,
7866 queue : & ' a QueueOf < Self :: Hardware > ,
7967 ) -> Self :: Weight < ' a > ;
8068
8169 fn load_moe < ' a > (
8270 & ' a self ,
83- which : BlkWeight ,
71+ which : LlamaBlkWeight ,
8472 iblk : usize ,
8573 iexp : usize ,
8674 queue : & ' a QueueOf < Self :: Hardware > ,
@@ -638,7 +626,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
638626 iblk : usize ,
639627 queue : & ' a QueueOf < W :: Hardware > ,
640628 ) -> Tensor < W :: Weight < ' a > > {
641- let w = self . weights . load_blk ( BlkWeight :: AttnNorm , iblk, queue) ;
629+ let w = self . weights . load_blk ( LlamaBlkWeight :: AttnNorm , iblk, queue) ;
642630 self . norm . clone ( ) . map ( |_| w)
643631 }
644632
@@ -648,7 +636,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
648636 iblk : usize ,
649637 queue : & ' a QueueOf < W :: Hardware > ,
650638 ) -> Tensor < W :: Weight < ' a > > {
651- let w = self . weights . load_blk ( BlkWeight :: AttnQKV , iblk, queue) ;
639+ let w = self . weights . load_blk ( LlamaBlkWeight :: AttnQKV , iblk, queue) ;
652640 self . attn_qkv . clone ( ) . map ( |_| w)
653641 }
654642
@@ -658,7 +646,9 @@ impl<W: WeightLoader> WeightDecorator<W> {
658646 iblk : usize ,
659647 queue : & ' a QueueOf < W :: Hardware > ,
660648 ) -> Tensor < W :: Weight < ' a > > {
661- let w = self . weights . load_blk ( BlkWeight :: AttnQKVBias , iblk, queue) ;
649+ let w = self
650+ . weights
651+ . load_blk ( LlamaBlkWeight :: AttnQKVBias , iblk, queue) ;
662652 self . attn_qkv_bias . clone ( ) . map ( |_| w)
663653 }
664654
@@ -668,7 +658,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
668658 iblk : usize ,
669659 queue : & ' a QueueOf < W :: Hardware > ,
670660 ) -> Tensor < W :: Weight < ' a > > {
671- let w = self . weights . load_blk ( BlkWeight :: AttnO , iblk, queue) ;
661+ let w = self . weights . load_blk ( LlamaBlkWeight :: AttnO , iblk, queue) ;
672662 self . attn_o . clone ( ) . map ( |_| w)
673663 }
674664
@@ -678,7 +668,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
678668 iblk : usize ,
679669 queue : & ' a QueueOf < W :: Hardware > ,
680670 ) -> Tensor < W :: Weight < ' a > > {
681- let w = self . weights . load_blk ( BlkWeight :: FfnNorm , iblk, queue) ;
671+ let w = self . weights . load_blk ( LlamaBlkWeight :: FfnNorm , iblk, queue) ;
682672 self . norm . clone ( ) . map ( |_| w)
683673 }
684674
@@ -688,7 +678,9 @@ impl<W: WeightLoader> WeightDecorator<W> {
688678 iblk : usize ,
689679 queue : & ' a QueueOf < W :: Hardware > ,
690680 ) -> Tensor < W :: Weight < ' a > > {
691- let w = self . weights . load_blk ( BlkWeight :: FfnGateInp , iblk, queue) ;
681+ let w = self
682+ . weights
683+ . load_blk ( LlamaBlkWeight :: FfnGateInp , iblk, queue) ;
692684 self . ffn_gate_inp . clone ( ) . map ( |_| w)
693685 }
694686
@@ -699,7 +691,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
699691 iexp : usize ,
700692 queue : & ' a QueueOf < W :: Hardware > ,
701693 ) -> Tensor < W :: Weight < ' a > > {
702- const WHICH : BlkWeight = BlkWeight :: FfnGateUp ;
694+ const WHICH : LlamaBlkWeight = LlamaBlkWeight :: FfnGateUp ;
703695 let w = if self . is_moe {
704696 self . weights . load_moe ( WHICH , iblk, iexp, queue)
705697 } else {
@@ -715,7 +707,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
715707 iexp : usize ,
716708 queue : & ' a QueueOf < W :: Hardware > ,
717709 ) -> Tensor < W :: Weight < ' a > > {
718- const WHICH : BlkWeight = BlkWeight :: FfnDown ;
710+ const WHICH : LlamaBlkWeight = LlamaBlkWeight :: FfnDown ;
719711 let w = if self . is_moe {
720712 self . weights . load_moe ( WHICH , iblk, iexp, queue)
721713 } else {
0 commit comments