Skip to content

Commit 7a71a4c

Browse files
CearXYdrMaster
authored andcommitted
feat(clip): 实现 resampler pos embd 计算
1 parent 81c8e1e commit 7a71a4c

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

models/clip/common-cpu/src/infer.rs

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,63 @@ fn test_infer() {
7878
}
7979

8080
fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
81-
let pos_w = w / d_patch;
82-
let pos_h = h / d_patch;
81+
let w = w / d_patch;
82+
let h = h / d_patch;
8383

84-
let mut ans = Tensor::new(ty::U32, &[1, pos_w * pos_h])
85-
.broadcast(0, n)
86-
.map(Blob::new);
84+
let mut ans = Tensor::new(ty::U32, &[1, h * w]).map(Blob::new);
8785
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<u32>() }) else {
8886
panic!()
8987
};
9088

91-
for i in 0..pos_h * pos_w {
92-
let y = (i / pos_w) * D_POS_EMBD / pos_h;
93-
let x = (i % pos_w) * D_POS_EMBD / pos_w;
89+
for i in 0..h * w {
90+
let r = i / w;
91+
let c = i % w;
92+
93+
let y = r * D_POS_EMBD / h;
94+
let x = c * D_POS_EMBD / w;
9495
data[i] = (y * D_POS_EMBD + x) as _;
9596
}
9697

97-
ans
98+
ans.broadcast(0, n)
99+
}
100+
101+
fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
102+
let w = w / d_patch;
103+
let h = h / d_patch;
104+
105+
let mut ans = Tensor::new(ty::F32, &[1, h * w, d]).map(Blob::new);
106+
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
107+
panic!()
108+
};
109+
110+
assert!(d % 4 == 0);
111+
let cache = sin_cos_cache(w.max(h), d / 4, 1e4);
112+
113+
for i in 0..h * w {
114+
let r = i / w;
115+
let c = i % w;
116+
117+
let data = &mut data[i * d..][..d];
118+
let d = d / 4;
119+
for i in 0..d {
120+
let (sin, cos) = cache[c * d + i];
121+
data[0 * d..][i] = sin;
122+
data[1 * d..][i] = cos;
123+
let (sin, cos) = cache[r * d + i];
124+
data[2 * d..][i] = sin;
125+
data[3 * d..][i] = cos;
126+
}
127+
}
128+
129+
ans.broadcast(0, n)
130+
}
131+
132+
fn sin_cos_cache(max_idx: usize, d: usize, theta: f32) -> Vec<(f32, f32)> {
133+
(0..max_idx * d)
134+
.map(|i| {
135+
let a = (i / d) as f32;
136+
let b = (i % d) as f32;
137+
(a * theta.powf(-(b / d as f32))).sin_cos()
138+
})
139+
.collect()
98140
}

0 commit comments

Comments
 (0)