Skip to content

Commit 256d219

Browse files
committed
feat(clip): 完成 clip 模型的 transformer 部分
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 2f059ce commit 256d219

File tree

6 files changed: 571 additions and 70 deletions.

gguf/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ mod macros {
144144
Err(e) => panic!("failed to read meta: {e:?}"),
145145
}
146146
};
147+
148+
($gguf:expr => (usize) $key:expr) => {
149+
$gguf.get_usize($key).unwrap()
150+
};
151+
($gguf:expr => (usize) $key:expr; $default:expr) => {
152+
match $gguf.get_usize($key) {
153+
Ok(val) => val,
154+
Err(gguf::GGufMetaError::NotExist) => $default,
155+
Err(e) => panic!("failed to read meta: {e:?}"),
156+
}
157+
};
147158
}
148159
#[macro_export]
149160
macro_rules! tensor {
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fn test_infer() {
2525
println!("{meta:#?}");
2626

2727
let &ClipMeta {
28-
dt_embd,
28+
dt,
2929

3030
d_image,
3131
d_patch,
@@ -42,7 +42,7 @@ fn test_infer() {
4242
let time = Instant::now();
4343
let slices = image
4444
.slice_uhd(9, d_image, d_patch)
45-
.normalize(dt_embd, image_mean, image_std);
45+
.normalize(dt, image_mean, image_std);
4646
println!("slice image {:?}", time.elapsed());
4747

4848
let weights = Weights::new(&storage);

models/clip/common-cpu/src/lib.rs

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use clip::{ClipStorage, WeightLoader};
2-
use operators::{common_cpu::Cpu, conv, QueueOf, TopoNode};
3-
use std::marker::PhantomData;
1+
use clip::{BlkWeight, ClipBlkStorage, ClipStorage, Tensor, WeightLoader};
2+
use operators::{common_cpu::Cpu, conv, ByteOf, QueueOf, TopoNode};
3+
use std::{marker::PhantomData, ops::Deref};
44

55
pub struct Operators<N = Cpu>(PhantomData<N>);
66

@@ -22,6 +22,18 @@ where
2222
type Conv = conv::common_cpu::ConvIm2Col;
2323
type AddRows = op!(add_rows);
2424
type LayerNorm = op!(layer_norm);
25+
type MatMul = op!(mat_mul);
26+
type Attention = op!(attention);
27+
type Gelu = op!(gelu);
28+
type Add = op!(add);
29+
type Rearrange = op!(rearrange);
30+
31+
fn debug<T>(tensor: &Tensor<T>)
32+
where
33+
T: Deref<Target = [ByteOf<Self::Hardware>]>,
34+
{
35+
println!("{tensor}")
36+
}
2537
}
2638

2739
impl<'w> Weights<'w> {
@@ -32,37 +44,67 @@ impl<'w> Weights<'w> {
3244

3345
impl WeightLoader for Weights<'_> {
3446
type Hardware = Cpu;
35-
type Weight<'s>
47+
type Memory<'s>
3648
= &'s [u8]
3749
where
3850
Self: 's;
3951

52+
fn load_blk(
53+
&self,
54+
which: BlkWeight,
55+
iblk: usize,
56+
_queue: &QueueOf<Self::Hardware>,
57+
) -> [Self::Memory<'_>; 2] {
58+
let ClipBlkStorage {
59+
attn_norm_w,
60+
attn_norm_b,
61+
attn_qkv_w,
62+
attn_qkv_b,
63+
attn_o_w,
64+
attn_o_b,
65+
ffn_norm_w,
66+
ffn_norm_b,
67+
ffn_up_w,
68+
ffn_up_b,
69+
ffn_down_w,
70+
ffn_down_b,
71+
} = &self.0.blocks[iblk];
72+
match which {
73+
BlkWeight::AttnNorm => [attn_norm_w, attn_norm_b],
74+
BlkWeight::AttnQKV => [attn_qkv_w, attn_qkv_b],
75+
BlkWeight::AttnO => [attn_o_w, attn_o_b],
76+
BlkWeight::FfnNorm => [ffn_norm_w, ffn_norm_b],
77+
BlkWeight::FfnUp => [ffn_up_w, ffn_up_b],
78+
BlkWeight::FfnDown => [ffn_down_w, ffn_down_b],
79+
}
80+
}
81+
4082
#[inline]
41-
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2] {
83+
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2] {
4284
[self.0.patch_embd_w, self.0.patch_embd_b]
4385
}
4486

4587
#[inline]
46-
fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a> {
88+
fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a> {
4789
self.0.pos_embd
4890
}
4991

5092
#[inline]
5193
fn pre_norm<'a>(
5294
&'a self,
5395
_queue: &'a QueueOf<Self::Hardware>,
54-
) -> Option<[Self::Weight<'a>; 2]> {
96+
) -> Option<[Self::Memory<'a>; 2]> {
5597
self.0.pre_norm
5698
}
5799

58100
#[inline]
59101
fn post_norm<'a>(
60102
&'a self,
61103
_queue: &'a QueueOf<Self::Hardware>,
62-
) -> Option<[Self::Weight<'a>; 2]> {
104+
) -> Option<[Self::Memory<'a>; 2]> {
63105
self.0.post_norm
64106
}
65107
}
66108

67109
#[cfg(test)]
68-
mod test_infer;
110+
mod infer;

Comments (0)