Skip to content

Commit 2e5d211

Browse files
committed
feat: 开始开发 clip 模型结构
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent fc013bf commit 2e5d211

File tree

7 files changed

+245
-6
lines changed

7 files changed

+245
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ itertools = "0.13"
2828
build-script-cfg = "0.0"
2929

3030
ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "48d36c5" }
31-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "48892b8", default-features = false }
31+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "1b08473", default-features = false }
3232

3333
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "6846d52" }
3434
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e2ec203" }

models/clip/common-cpu/src/lib.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,41 @@
1+
use clip::{ClipStorage, WeightLoader};
2+
use operators::{common_cpu::Cpu, conv, QueueOf, TopoNode};
3+
use std::marker::PhantomData;
4+
5+
pub struct Operators<N = Cpu>(PhantomData<N>);
6+
7+
pub struct Weights<'w> {
8+
patch_embd_w: &'w [u8],
9+
patch_embd_b: &'w [u8],
10+
}
11+
12+
impl<N> clip::Operators for Operators<N>
13+
where
14+
N: TopoNode<Cpu>,
15+
{
16+
type Hardware = Cpu;
17+
type TopoNode = Cpu;
18+
type Conv = conv::common_cpu::ConvIm2Col;
19+
}
20+
21+
impl<'w> Weights<'w> {
22+
pub fn new(model: &'w ClipStorage<&'w [u8]>) -> Self {
23+
Self {
24+
patch_embd_w: model.patch_embd_w,
25+
patch_embd_b: model.patch_embd_b,
26+
}
27+
}
28+
}
29+
30+
impl WeightLoader for Weights<'_> {
31+
type Hardware = Cpu;
32+
type Weight<'s> = &'s [u8] where Self: 's;
33+
34+
#[inline]
35+
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2] {
36+
[self.patch_embd_w, self.patch_embd_b]
37+
}
38+
}
39+
140
#[cfg(test)]
241
mod test_infer;

models/clip/common-cpu/src/test_infer.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
use clip::{ClipMeta, ClipStorage, Image};
1+
use crate::{Operators, Weights};
2+
use clip::{ClipArgs, ClipMeta, ClipStorage, ClipWorker, Image};
23
use gguf::GGufModel;
4+
use operators::common_cpu::{Cpu, ThisThread};
35
use std::time::Instant;
46
use test_utils::Inference;
57

8+
type Worker<'w> = ClipWorker<Operators, Weights<'w>>;
9+
610
#[test]
711
fn test_infer() {
812
let Some(Inference { model, .. }) = Inference::load() else {
@@ -33,8 +37,38 @@ fn test_infer() {
3337
println!("load image {:?}", time.elapsed());
3438

3539
let time = Instant::now();
36-
let _slices = image
40+
let slices = image
3741
.slice_uhd(9, d_image, d_patch)
3842
.normalize(dt_embd, image_mean, image_std);
3943
println!("slice image {:?}", time.elapsed());
44+
45+
let weights = Weights::new(&storage);
46+
let mut worker = Worker::new(&Cpu, meta.clone(), weights);
47+
48+
let whole = slices.whole();
49+
worker
50+
.launch(
51+
ClipArgs {
52+
raw: whole.to_nchw(),
53+
},
54+
&mut [],
55+
&ThisThread,
56+
)
57+
.unwrap();
58+
59+
let [x, y] = slices.grid();
60+
for i in 0..y {
61+
for j in 0..x {
62+
let patch = slices.patch(j, i);
63+
worker
64+
.launch(
65+
ClipArgs {
66+
raw: patch.to_nchw(),
67+
},
68+
&mut [],
69+
&ThisThread,
70+
)
71+
.unwrap();
72+
}
73+
}
4074
}

models/clip/common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
1010
common.workspace = true
1111
gguf.workspace = true
1212
tensor.workspace = true
13+
operators.workspace = true
1314
itertools.workspace = true
1415
image = "0.25"
1516
rayon = "1.10"

models/clip/common/src/args.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
use operators::Hardware;
2+
use tensor::Tensor;
3+
4+
pub struct Args<'a, H: Hardware> {
5+
/// shape: [n, c, h, w]
6+
pub raw: Tensor<&'a [H::Byte]>,
7+
}

models/clip/common/src/compute.rs

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
use super::{args::Args, ClipMeta};
2+
use operators::{
3+
conv::{self, Conv},
4+
ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode,
5+
};
6+
use std::ops::{Deref, DerefMut};
7+
use tensor::Tensor;
8+
9+
pub trait Operators {
10+
type Hardware: Hardware;
11+
type TopoNode: TopoNode<Self::Hardware>;
12+
type Conv: Conv<Self::Hardware>;
13+
}
14+
15+
pub trait WeightLoader {
16+
type Hardware: Hardware;
17+
type Weight<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
18+
where
19+
Self: 's;
20+
21+
fn patch_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2];
22+
}
23+
24+
pub struct ClipWorker<Ops: Operators, W> {
25+
meta: ClipMeta,
26+
weights: WeightDecorator<W>,
27+
conv: Ops::Conv,
28+
pub debug: bool,
29+
}
30+
31+
impl<Ops: Operators, W> ClipWorker<Ops, W> {
32+
pub fn new(node: &Ops::TopoNode, meta: ClipMeta, weights: W) -> Self {
33+
let processor = node.processor();
34+
Self {
35+
weights: meta.decorator(weights),
36+
meta,
37+
conv: Ops::Conv::new(processor),
38+
debug: true,
39+
}
40+
}
41+
42+
#[inline]
43+
pub const fn meta(&self) -> &ClipMeta {
44+
&self.meta
45+
}
46+
}
47+
48+
impl<Ops, W> ClipWorker<Ops, W>
49+
where
50+
Ops: Operators,
51+
W: WeightLoader<Hardware = Ops::Hardware>,
52+
ByteOf<Ops::Hardware>: 'static,
53+
{
54+
pub fn launch<QA>(
55+
&mut self,
56+
args: Args<Ops::Hardware>,
57+
workspace: &mut [ByteOf<Ops::Hardware>],
58+
queue_alloc: &QA,
59+
) -> Result<(), LaunchError>
60+
where
61+
QA: QueueAlloc<Hardware = Ops::Hardware>,
62+
{
63+
let Args { raw } = args;
64+
let queue = queue_alloc.queue();
65+
66+
let ClipMeta { dt_embd, .. } = self.meta;
67+
68+
let [k, b] = self.weights.patch_embd(queue);
69+
let &[n, _, h, w] = raw.shape() else {
70+
unreachable!()
71+
};
72+
let &[m, _, hk, wk] = k.shape() else {
73+
unreachable!()
74+
};
75+
76+
let mut embd = Tensor::new(dt_embd, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
77+
self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)
78+
}
79+
}
80+
81+
#[allow(clippy::too_many_arguments)]
82+
impl<Ops, W> ClipWorker<Ops, W>
83+
where
84+
Ops: Operators,
85+
W: WeightLoader<Hardware = Ops::Hardware>,
86+
{
87+
fn conv<Y, X, W_, B, QA>(
88+
&self,
89+
y: &mut Tensor<Y>,
90+
x: &Tensor<X>,
91+
w: &Tensor<W_>,
92+
b: &Tensor<B>,
93+
workspace: &mut [ByteOf<Ops::Hardware>],
94+
queue_alloc: &QA,
95+
) -> Result<(), LaunchError>
96+
where
97+
Y: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
98+
X: Deref<Target = [ByteOf<Ops::Hardware>]>,
99+
W_: Deref<Target = [ByteOf<Ops::Hardware>]>,
100+
B: Deref<Target = [ByteOf<Ops::Hardware>]>,
101+
QA: QueueAlloc<Hardware = Ops::Hardware>,
102+
{
103+
self.conv.launch(
104+
&conv::Args {
105+
y_layout: y.layout(),
106+
y_base: y.base_mut(),
107+
x_layout: x.layout(),
108+
x_base: x.base(),
109+
w_layout: w.layout(),
110+
w_base: w.base(),
111+
b_layout: b.layout(),
112+
b_base: b.base(),
113+
strides: [self.meta.d_patch; 2],
114+
dilations: [1; 2],
115+
pads: [0; 4],
116+
},
117+
workspace,
118+
queue_alloc,
119+
)
120+
}
121+
}
122+
123+
struct WeightDecorator<W> {
124+
weights: W,
125+
patch_embd_w: Tensor<usize>,
126+
patch_embd_b: Tensor<usize>,
127+
}
128+
129+
impl ClipMeta {
130+
fn decorator<W>(&self, weights: W) -> WeightDecorator<W> {
131+
WeightDecorator {
132+
patch_embd_w: self.patch_embd_w(),
133+
patch_embd_b: self.patch_embd_b(),
134+
weights,
135+
}
136+
}
137+
}
138+
139+
impl<W: WeightLoader> WeightDecorator<W> {
140+
#[inline]
141+
pub fn patch_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> [Tensor<W::Weight<'a>>; 2] {
142+
let [w, b] = self.weights.patch_embd(queue);
143+
[
144+
self.patch_embd_w.clone().map(|_| w),
145+
self.patch_embd_b.clone().map(|_| b),
146+
]
147+
}
148+
}

models/clip/common/src/lib.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
1+
mod args;
2+
mod compute;
13
mod image;
24
mod storage;
35

46
use gguf::ggml_quants::digit_layout::DigitLayout;
5-
use tensor::Tensor;
67

8+
pub use args::Args as ClipArgs;
9+
pub use compute::{ClipWorker, Operators, WeightLoader};
710
pub use image::{Image, ImageGrid};
811
pub use storage::Storage as ClipStorage;
12+
pub use tensor::Tensor;
13+
pub mod ext {
14+
pub use gguf::{
15+
ext::{utok, Mmap},
16+
ggml_quants,
17+
};
18+
}
919

1020
#[derive(Clone, Debug)]
1121
pub struct ClipMeta {
@@ -67,12 +77,12 @@ impl ClipMeta {
6777
}
6878
}
6979

70-
pub fn patch_embd(&self) -> Tensor<usize> {
80+
pub fn patch_embd_w(&self) -> Tensor<usize> {
7181
let &Self { d, d_patch, .. } = self;
7282
Tensor::new(self.dt_mat, &[d, 3, d_patch, d_patch])
7383
}
7484

75-
pub fn patch_embd_bias(&self) -> Tensor<usize> {
85+
pub fn patch_embd_b(&self) -> Tensor<usize> {
7686
let &Self { d, .. } = self;
7787
Tensor::new(self.dt_bias, &[d])
7888
}

0 commit comments

Comments
 (0)