Skip to content

Commit 88dc961

Browse files
committed
feat(clip): 完善 clip 结构并为 softmax 添加并行加速
Signed-off-by: YdrMaster <[email protected]>
1 parent c9b6bdf commit 88dc961

File tree

7 files changed

+31
-13
lines changed

7 files changed

+31
-13
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ itertools = "0.13"
3838
env_logger = "0.11"
3939
build-script-cfg = "0.0"
4040

41-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "0bd4107", default-features = false }
41+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "8712870", default-features = false }
4242

4343
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
4444
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/clip/common-cpu/src/infer.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ fn test_infer() {
5252
let mut worker = Worker::new(&Cpu, meta.clone(), weights);
5353

5454
let whole = slices.whole();
55+
let mut img_embd = meta.projector.img_embd(meta.dt, 1).map(Blob::new);
5556
worker
5657
.launch(
5758
ClipArgs {
59+
img_embd: img_embd.map_slice_mut(),
5860
raw: whole.to_nchw(),
5961
pos: pos70(whole.shape(), d_patch).map_slice(),
6062
pos_resampler: pos_resampler(3584, whole.shape(), d_patch).map_slice(),
@@ -65,12 +67,14 @@ fn test_infer() {
6567
.unwrap();
6668

6769
if let Some(patches) = slices.patches_nchw() {
68-
let &[_, 3, h, w] = patches.shape() else {
70+
let &[batch, 3, h, w] = patches.shape() else {
6971
unreachable!()
7072
};
73+
let mut img_embd = meta.projector.img_embd(meta.dt, batch).map(Blob::new);
7174
worker
7275
.launch(
7376
ClipArgs {
77+
img_embd: img_embd.map_slice_mut(),
7478
raw: patches.map_slice(),
7579
pos: pos70([w, h], d_patch).map_slice(),
7680
pos_resampler: pos_resampler(3584, [w, h], d_patch).map_slice(),

models/clip/common/src/args.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
use tensor::Tensor;
33

44
pub struct Args<'a, H: Hardware> {
5+
/// shape: [batch x projector_dp, projector_d]
6+
pub img_embd: Tensor<&'a mut [H::Byte]>,
57
/// shape: [n, c, h, w]
68
pub raw: Tensor<&'a [H::Byte]>,
79
/// shape: [h x w]

models/clip/common/src/compute.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ where
149149
{
150150
let time = Instant::now();
151151
let Args {
152+
img_embd: proj_q,
152153
raw,
153154
pos,
154155
pos_resampler,
@@ -317,10 +318,7 @@ where
317318
let [w, b] = weights.resampler_attn_o(queue);
318319
let attn_o = (attn_w.clone().map(|_| w), Some(attn_b.clone().map(|_| b)));
319320

320-
let qo = Tensor::new(dt, &[batch * dq, d]);
321-
322-
let (buf, workspace) = workspace.split_at_mut(*qo.get());
323-
let mut q_ = qo.clone().map(|_| buf);
321+
let mut q_ = proj_q;
324322
{
325323
let mut q_ = q_.map_slice_mut().tile(0, &[batch, dq]);
326324
{
@@ -363,18 +361,19 @@ where
363361
}
364362
let o = q_;
365363

366-
let (buf, workspace) = workspace.split_at_mut(*qo.get());
367-
let mut o_ = qo.map(|_| buf);
364+
let o_ = Tensor::new(o.dt(), o.shape());
365+
let (buf, workspace) = workspace.split_at_mut(*o_.get());
366+
let mut o_ = o_.map(|_| buf);
368367
self.mat_mul(&mut o_, &o, attn_o, workspace, queue_alloc)?;
369368

370369
let [w, b] = weights.resampler_ln_post(queue);
371370
let ln_post = [ln.clone().map(|_| w), ln.clone().map(|_| b)];
372371
let inplace = unsafe { o_.map_slice_static() };
373372
self.layer_norm(&mut o_, &inplace, ln_post, workspace, queue_alloc)?;
374373

375-
let mut out = o;
374+
let mut img_embd = o;
376375
let w = attn_w.map(|_| weights.resampler_proj(queue));
377-
self.mat_mul(&mut out, &o_, (w, None), workspace, queue_alloc)?
376+
self.mat_mul(&mut img_embd, &o_, (w, None), workspace, queue_alloc)?
378377
}
379378
}
380379

models/clip/common/src/projector/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub(crate) mod resampler;
22

3-
use gguf::{GGufMetaMapExt, GGufModel};
3+
use gguf::{ggml_quants::digit_layout::DigitLayout, GGufMetaMapExt, GGufModel};
4+
use tensor::Tensor;
45

56
#[derive(Clone, Debug)]
67
pub enum ProjectorMeta {
@@ -14,6 +15,12 @@ impl ProjectorMeta {
1415
projector => todo!("unsupported projector type: {projector}"),
1516
}
1617
}
18+
19+
pub fn img_embd(&self, dt: DigitLayout, batch: usize) -> Tensor<usize> {
20+
match self {
21+
ProjectorMeta::Resampler(meta) => meta.img_embd(dt, batch),
22+
}
23+
}
1724
}
1825

1926
#[derive(Clone)]

models/clip/common/src/projector/resampler.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use gguf::{tensor, GGufMetaMapExt, GGufModel};
1+
use gguf::{ggml_quants::digit_layout::DigitLayout, tensor, GGufMetaMapExt, GGufModel};
2+
use tensor::Tensor;
23

34
#[derive(Clone, Debug)]
45
pub struct Meta {
@@ -23,6 +24,11 @@ impl Meta {
2324
version => todo!("Unsupported MiniCPM version: {version}"),
2425
}
2526
}
27+
28+
#[inline]
29+
pub fn img_embd(&self, dt: DigitLayout, batch: usize) -> Tensor<usize> {
30+
Tensor::new(dt, &[batch * self.dq, self.d])
31+
}
2632
}
2733

2834
#[derive(Clone)]

tensor/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl<T> Tensor<T> {
8484

8585
let merged = self
8686
.layout
87-
.merge_be(0, self.layout.ndim())
87+
.merge_free(0, self.layout.ndim())
8888
.expect("dense tensor is castable");
8989
let &[d] = merged.shape() else { unreachable!() };
9090
let &[s] = merged.strides() else {

0 commit comments

Comments
 (0)