Skip to content

Commit 0255b79

Browse files
CearXYdrMaster
authored andcommitted
style: cleanup resampler-pos
1 parent 81c8e1e commit 0255b79

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed

models/clip/common-cpu/src/infer.rs

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,163 @@ fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
9696

9797
ans
9898
}
99+
100+
#[cfg(test)]
101+
mod test_pos_embd {
102+
use super::*;
103+
104+
mod c_like {
105+
use super::*;
106+
pub(super) fn pos_resampler(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
107+
let d = 3584;
108+
let pos_w = w / d_patch;
109+
let pos_h = h / d_patch;
110+
111+
let mut ans = Tensor::new(ty::F32, &[1, pos_w * pos_h, d])
112+
.broadcast(0, n)
113+
.map(Blob::new);
114+
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
115+
panic!()
116+
};
117+
118+
let pos_embed_t = get_2d_sincos_pos_embed(d, (pos_w, pos_h));
119+
120+
for i in 0..pos_w * pos_h {
121+
for j in 0..d {
122+
data[i * d + j] = pos_embed_t[i][j];
123+
}
124+
}
125+
ans
126+
}
127+
128+
fn get_2d_sincos_pos_embed(embed_dim: usize, image_size: (usize, usize)) -> Vec<Vec<f32>> {
129+
let (grid_h_size, grid_w_size) = image_size;
130+
131+
let mut grid_h: Vec<f32> = (0..grid_h_size).map(|i| i as f32).collect();
132+
let mut grid_w: Vec<f32> = (0..grid_w_size).map(|i| i as f32).collect();
133+
134+
let mut grid: Vec<Vec<f32>> = vec![vec![0.0; grid_w_size]; grid_h_size];
135+
for h in 0..grid_h_size {
136+
for w in 0..grid_w_size {
137+
grid[h][w] = grid_w[w];
138+
}
139+
}
140+
141+
let mut grid_2d: Vec<Vec<Vec<f32>>> = vec![grid.clone(), grid.clone()];
142+
for h in 0..grid_h_size {
143+
for w in 0..grid_w_size {
144+
grid_2d[0][h][w] = grid_h[h];
145+
grid_2d[1][h][w] = grid_w[w];
146+
}
147+
}
148+
149+
let pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
150+
151+
let (H, W) = image_size;
152+
let mut pos_embed_2d: Vec<Vec<f32>> = vec![vec![0.0; embed_dim]; H * W];
153+
for h in 0..H {
154+
for w in 0..W {
155+
pos_embed_2d[w * H + h] = pos_embed_3d[h][w].clone();
156+
}
157+
}
158+
159+
pos_embed_2d
160+
}
161+
162+
fn get_2d_sincos_pos_embed_from_grid(
163+
embed_dim: usize,
164+
grid: Vec<Vec<Vec<f32>>>,
165+
) -> Vec<Vec<Vec<f32>>> {
166+
assert!(embed_dim % 2 == 0);
167+
168+
let emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0].clone()); // (H, W, D/2)
169+
let emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1].clone()); // (H, W, D/2)
170+
171+
let H = emb_h.len();
172+
let W = emb_h[0].len();
173+
let mut emb: Vec<Vec<Vec<f32>>> = vec![vec![vec![0.0; embed_dim]; W]; H];
174+
175+
for h in 0..H {
176+
for w in 0..W {
177+
for d in 0..(embed_dim / 2) {
178+
emb[h][w][d] = emb_h[h][w][d];
179+
emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
180+
}
181+
}
182+
}
183+
184+
emb
185+
}
186+
187+
fn get_1d_sincos_pos_embed_from_grid_new(
188+
embed_dim: usize,
189+
pos: Vec<Vec<f32>>,
190+
) -> Vec<Vec<Vec<f32>>> {
191+
assert!(embed_dim % 2 == 0);
192+
let H = pos.len();
193+
let W = pos[0].len();
194+
195+
let mut omega: Vec<f32> = (0..embed_dim / 2)
196+
.map(|i| 1.0 / 10000.0f32.powi(i as i32 / (embed_dim / 2) as i32))
197+
.collect();
198+
199+
let mut emb: Vec<Vec<Vec<f32>>> = vec![vec![vec![0.0; embed_dim]; W]; H];
200+
for h in 0..H {
201+
for w in 0..W {
202+
for d in 0..(embed_dim / 2) {
203+
let out_value = pos[h][w] * omega[d];
204+
emb[h][w][d] = out_value.sin();
205+
emb[h][w][d + embed_dim / 2] = out_value.cos();
206+
}
207+
}
208+
}
209+
210+
emb
211+
}
212+
}
213+
214+
mod rust_style {
215+
use super::*;
216+
pub(super) fn pos_resampler(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
217+
let d = 3584;
218+
assert!(d % 4 == 0);
219+
220+
let pos_w = w / d_patch;
221+
let pos_h = h / d_patch;
222+
223+
let mut ans = Tensor::new(ty::F32, &[1, pos_w * pos_h, d]).map(Blob::new);
224+
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
225+
panic!()
226+
};
227+
set_2d_sincos_pos_embed(data, d, (pos_w, pos_h));
228+
229+
ans.broadcast(0, n)
230+
}
231+
232+
fn set_2d_sincos_pos_embed(data: &mut [f32], d: usize, (h, w): (usize, usize)) {
233+
for r in 0..h {
234+
for c in 0..w {
235+
let data = &mut data[(c * h + r) * d..][..d];
236+
let d = d / 4;
237+
238+
for i in 0..d {
239+
let (sin, cos) = (r as f32).sin_cos();
240+
data[0 * d..][i] = sin;
241+
data[1 * d..][i] = cos;
242+
243+
let (sin, cos) = (c as f32).sin_cos();
244+
data[2 * d..][i] = sin;
245+
data[3 * d..][i] = cos;
246+
}
247+
}
248+
}
249+
}
250+
}
251+
252+
#[test]
253+
fn test_eq() {
254+
let a = c_like::pos_resampler(4, [336, 224], 14).take();
255+
let b = rust_style::pos_resampler(4, [336, 224], 14).take();
256+
assert_eq!(&*a, &*b);
257+
}
258+
}

0 commit comments

Comments
 (0)