Skip to content

Commit 357bf29

Browse files
committed
feat(clip): 添加 pos 计算
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 9bfe2a5 commit 357bf29

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

models/clip/common-cpu/src/test_infer.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use crate::{Operators, Weights};
2-
use clip::{ClipArgs, ClipMeta, ClipStorage, ClipWorker, Image};
3-
use gguf::GGufModel;
4-
use operators::common_cpu::{Cpu, ThisThread};
2+
use clip::{ClipArgs, ClipMeta, ClipStorage, ClipWorker, Image, Tensor};
3+
use gguf::{ggml_quants::digit_layout::types as ty, GGufModel};
4+
use operators::{
5+
common_cpu::{Cpu, ThisThread},
6+
Blob,
7+
};
58
use std::time::Instant;
69
use test_utils::Inference;
710

@@ -50,21 +53,51 @@ fn test_infer() {
5053
.launch(
5154
ClipArgs {
5255
raw: whole.to_nchw(),
56+
pos: pos70(whole.shape(), d_patch).map_slice(),
5357
},
5458
&mut [],
5559
&ThisThread,
5660
)
5761
.unwrap();
5862

5963
if let Some(patches) = slices.patches_nchw() {
64+
let &[_, 3, h, w] = patches.shape() else {
65+
unreachable!()
66+
};
6067
worker
6168
.launch(
6269
ClipArgs {
6370
raw: patches.map_slice(),
71+
pos: pos70([w, h], d_patch).map_slice(),
6472
},
6573
&mut [],
6674
&ThisThread,
6775
)
6876
.unwrap();
6977
}
7078
}
79+
80+
fn pos70([w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
81+
let pos_w = w / d_patch;
82+
let pos_h = h / d_patch;
83+
let mut bucket_corrds_h = [0; 70];
84+
let mut bucket_corrds_w = [0; 70];
85+
for i in 0..pos_w {
86+
bucket_corrds_w[i] = ((70 * i) as f64 / pos_w as f64) as _;
87+
}
88+
for i in 0..pos_h {
89+
bucket_corrds_h[i] = ((70 * i) as f64 / pos_h as f64) as _;
90+
}
91+
92+
let mut ans = Tensor::new(ty::U32, &[pos_w * pos_h]).map(Blob::new);
93+
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<u32>() }) else {
94+
panic!()
95+
};
96+
97+
let f = |i, d| ((70 * i) as f64 / d as f64) as u32;
98+
for i in 0..pos_h * pos_w {
99+
data[i] = f(i / pos_w, pos_h) * 70 + f(i % pos_w, pos_w);
100+
}
101+
102+
ans
103+
}

models/clip/common/src/args.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ use tensor::Tensor;
44
pub struct Args<'a, H: Hardware> {
55
/// shape: [n, c, h, w]
66
pub raw: Tensor<&'a [H::Byte]>,
7+
/// shape: [h x w]
8+
pub pos: Tensor<&'a [H::Byte]>,
79
}

models/clip/common/src/compute.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ where
6464
QA: QueueAlloc<Hardware = Ops::Hardware>,
6565
{
6666
let time = Instant::now();
67-
let Args { raw } = args;
67+
let Args { raw, .. } = args;
6868
let queue = queue_alloc.queue();
6969

7070
let ClipMeta { dt_embd, .. } = self.meta;
@@ -80,6 +80,8 @@ where
8080
let mut embd = Tensor::new(dt_embd, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
8181
self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)?;
8282

83+
let _embd = embd.merge(2..4).unwrap().transpose(&[2, 1]);
84+
8385
if self.debug {
8486
println!("encode {n} x {h} x {w} image in {:?}", time.elapsed());
8587
}

0 commit comments

Comments
 (0)