Skip to content

Commit 514bebf

Browse files
committed
temp(clip): 继续实现 clip 模型,部分实现 transformer
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 4a19e67 commit 514bebf

File tree

5 files changed

+381
-47
lines changed

5 files changed

+381
-47
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fn test_infer() {
2525
println!("{meta:#?}");
2626

2727
let &ClipMeta {
28-
dt_embd,
28+
dt,
2929

3030
d_image,
3131
d_patch,
@@ -42,7 +42,7 @@ fn test_infer() {
4242
let time = Instant::now();
4343
let slices = image
4444
.slice_uhd(9, d_image, d_patch)
45-
.normalize(dt_embd, image_mean, image_std);
45+
.normalize(dt, image_mean, image_std);
4646
println!("slice image {:?}", time.elapsed());
4747

4848
let weights = Weights::new(&storage);

models/clip/common-cpu/src/lib.rs

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use clip::{ClipStorage, WeightLoader};
2-
use operators::{common_cpu::Cpu, conv, QueueOf, TopoNode};
3-
use std::marker::PhantomData;
1+
use clip::{BlkWeight, ClipBlkStorage, ClipStorage, Tensor, WeightLoader};
2+
use operators::{common_cpu::Cpu, conv, ByteOf, QueueOf, TopoNode};
3+
use std::{marker::PhantomData, ops::Deref};
44

55
/// Marker type generic over `N`, which defaults to `Cpu`.
///
/// The only field is a `PhantomData<N>`, so the struct stores no data of its own.
pub struct Operators<N = Cpu>(PhantomData<N>);
66

@@ -21,7 +21,16 @@ where
2121
type TopoNode = Cpu;
2222
type Conv = conv::common_cpu::ConvIm2Col;
2323
type AddRows = op!(add_rows);
24+
type Rearrange = op!(rearrange);
2425
type LayerNorm = op!(layer_norm);
26+
type MatMul = op!(mat_mul);
27+
28+
/// Debug hook: prints `tensor` to standard output via its `Display` formatting.
///
/// `T` must dereference to a slice of `ByteOf<Self::Hardware>`.
fn debug<T>(tensor: &Tensor<T>)
29+
where
30+
T: Deref<Target = [ByteOf<Self::Hardware>]>,
31+
{
32+
println!("{tensor}")
33+
}
2534
}
2635

2736
impl<'w> Weights<'w> {
@@ -32,37 +41,67 @@ impl<'w> Weights<'w> {
3241

3342
impl WeightLoader for Weights<'_> {
3443
type Hardware = Cpu;
35-
type Weight<'s>
44+
type Memory<'s>
3645
= &'s [u8]
3746
where
3847
Self: 's;
3948

49+
/// Looks up the two buffers of block `iblk` that correspond to `which`.
///
/// The `ClipBlkStorage` at `self.0.blocks[iblk]` is destructured in full, and `which`
/// selects the matching `*_w` / `*_b` pair, returned in that order
/// (presumably weight then bias, going by the field names).
///
/// `_queue` is not used by this implementation.
///
/// NOTE(review): `self.0.blocks[iblk]` is a plain index expression; presumably it panics
/// when `iblk` is out of range. Confirm against the type of `blocks`.
fn load_blk(
50+
&self,
51+
which: BlkWeight,
52+
iblk: usize,
53+
_queue: &QueueOf<Self::Hardware>,
54+
) -> [Self::Memory<'_>; 2] {
55+
let ClipBlkStorage {
56+
attn_norm_w,
57+
attn_norm_b,
58+
attn_qkv_w,
59+
attn_qkv_b,
60+
attn_o_w,
61+
attn_o_b,
62+
ffn_norm_w,
63+
ffn_norm_b,
64+
ffn_up_w,
65+
ffn_up_b,
66+
ffn_down_w,
67+
ffn_down_b,
68+
} = &self.0.blocks[iblk];
69+
// No wildcard arm: each listed `BlkWeight` variant maps to its own `_w` / `_b` pair.
match which {
70+
BlkWeight::AttnNorm => [attn_norm_w, attn_norm_b],
71+
BlkWeight::AttnQKV => [attn_qkv_w, attn_qkv_b],
72+
BlkWeight::AttnO => [attn_o_w, attn_o_b],
73+
BlkWeight::FfnNorm => [ffn_norm_w, ffn_norm_b],
74+
BlkWeight::FfnUp => [ffn_up_w, ffn_up_b],
75+
BlkWeight::FfnDown => [ffn_down_w, ffn_down_b],
76+
}
77+
}
78+
4079
/// Returns the pair `[patch_embd_w, patch_embd_b]` read from `self.0`.
///
/// `_queue` is not used by this implementation. This diff only changes the return
/// type's associated-type name from `Weight` to `Memory`.
#[inline]
41-
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2] {
80+
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2] {
4281
[self.0.patch_embd_w, self.0.patch_embd_b]
4382
}
4483

4584
/// Returns the `pos_embd` buffer read from `self.0`.
///
/// `_queue` is not used by this implementation. This diff only changes the return
/// type's associated-type name from `Weight` to `Memory`.
#[inline]
46-
fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a> {
85+
fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a> {
4786
self.0.pos_embd
4887
}
4988

5089
/// Returns `self.0.pre_norm` unchanged: an optional pair of buffers.
///
/// `_queue` is not used by this implementation. This diff only changes the return
/// type's associated-type name from `Weight` to `Memory`.
#[inline]
5190
fn pre_norm<'a>(
5291
&'a self,
5392
_queue: &'a QueueOf<Self::Hardware>,
54-
) -> Option<[Self::Weight<'a>; 2]> {
93+
) -> Option<[Self::Memory<'a>; 2]> {
5594
self.0.pre_norm
5695
}
5796

5897
/// Returns `self.0.post_norm` unchanged: an optional pair of buffers.
///
/// `_queue` is not used by this implementation. This diff only changes the return
/// type's associated-type name from `Weight` to `Memory`.
#[inline]
5998
fn post_norm<'a>(
6099
&'a self,
61100
_queue: &'a QueueOf<Self::Hardware>,
62-
) -> Option<[Self::Weight<'a>; 2]> {
101+
) -> Option<[Self::Memory<'a>; 2]> {
63102
self.0.post_norm
64103
}
65104
}
66105

67106
#[cfg(test)]
68-
mod test_infer;
107+
mod infer;

0 commit comments

Comments
 (0)