Skip to content

Commit 3d8b052

Browse files
committed
feat(clip): 支持 resampler pos embd,pos 广播挪到模型内
Signed-off-by: YdrMaster <[email protected]>
1 parent 7a71a4c commit 3d8b052

File tree

3 files changed

+39
-22
lines changed

3 files changed

+39
-22
lines changed

models/clip/common-cpu/src/infer.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::{Operators, Weights};
22
use clip::{ClipArgs, ClipMeta, ClipStorage, ClipWorker, Image, Tensor, D_POS_EMBD};
3-
use gguf::{ggml_quants::digit_layout::types as ty, GGufModel};
3+
use gguf::{
4+
ggml_quants::{digit_layout::types as ty, f16},
5+
GGufModel,
6+
};
47
use operators::{
58
common_cpu::{Cpu, ThisThread},
69
Blob,
@@ -53,22 +56,24 @@ fn test_infer() {
5356
.launch(
5457
ClipArgs {
5558
raw: whole.to_nchw(),
56-
pos: pos70(1, whole.shape(), d_patch).map_slice(),
59+
pos: pos70(whole.shape(), d_patch).map_slice(),
60+
pos_resampler: pos_resampler(3584, whole.shape(), d_patch).map_slice(),
5761
},
5862
&mut [],
5963
&ThisThread,
6064
)
6165
.unwrap();
6266

6367
if let Some(patches) = slices.patches_nchw() {
64-
let &[n, 3, h, w] = patches.shape() else {
68+
let &[_, 3, h, w] = patches.shape() else {
6569
unreachable!()
6670
};
6771
worker
6872
.launch(
6973
ClipArgs {
7074
raw: patches.map_slice(),
71-
pos: pos70(n, [w, h], d_patch).map_slice(),
75+
pos: pos70([w, h], d_patch).map_slice(),
76+
pos_resampler: pos_resampler(3584, [w, h], d_patch).map_slice(),
7277
},
7378
&mut [],
7479
&ThisThread,
@@ -77,7 +82,7 @@ fn test_infer() {
7782
}
7883
}
7984

80-
fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
85+
fn pos70([w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
8186
let w = w / d_patch;
8287
let h = h / d_patch;
8388

@@ -95,15 +100,15 @@ fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
95100
data[i] = (y * D_POS_EMBD + x) as _;
96101
}
97102

98-
ans.broadcast(0, n)
103+
ans
99104
}
100105

101-
fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
106+
fn pos_resampler(d: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
102107
let w = w / d_patch;
103108
let h = h / d_patch;
104109

105-
let mut ans = Tensor::new(ty::F32, &[1, h * w, d]).map(Blob::new);
106-
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
110+
let mut ans = Tensor::new(ty::F16, &[1, h * w, d]).map(Blob::new);
111+
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f16>() }) else {
107112
panic!()
108113
};
109114

@@ -118,15 +123,15 @@ fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tens
118123
let d = d / 4;
119124
for i in 0..d {
120125
let (sin, cos) = cache[c * d + i];
121-
data[0 * d..][i] = sin;
122-
data[1 * d..][i] = cos;
126+
data[0 * d..][i] = f16::from_f32(sin);
127+
data[1 * d..][i] = f16::from_f32(cos);
123128
let (sin, cos) = cache[r * d + i];
124-
data[2 * d..][i] = sin;
125-
data[3 * d..][i] = cos;
129+
data[2 * d..][i] = f16::from_f32(sin);
130+
data[3 * d..][i] = f16::from_f32(cos);
126131
}
127132
}
128133

129-
ans.broadcast(0, n)
134+
ans
130135
}
131136

132137
fn sin_cos_cache(max_idx: usize, d: usize, theta: f32) -> Vec<(f32, f32)> {

models/clip/common/src/args.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use tensor::Tensor;
44
pub struct Args<'a, H: Hardware> {
55
/// shape: [n, c, h, w]
66
pub raw: Tensor<&'a [H::Byte]>,
7-
/// shape: [n, h x w]
7+
/// shape: [h x w]
88
pub pos: Tensor<&'a [H::Byte]>,
9+
/// shape: [h x w, resampler.d]
10+
pub pos_resampler: Tensor<&'a [H::Byte]>,
911
}

models/clip/common/src/compute.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,11 @@ where
145145
QA: QueueAlloc<Hardware = Ops::Hardware>,
146146
{
147147
let time = Instant::now();
148-
let Args { raw, pos } = args;
148+
let Args {
149+
raw,
150+
pos,
151+
pos_resampler,
152+
} = args;
149153
let ClipMeta {
150154
dt,
151155
dt_norm,
@@ -176,14 +180,16 @@ where
176180
let mut embd = Tensor::new(embd_.dt(), embd_.shape()).map(|s| queue_alloc.alloc(s));
177181
self.rearrange(&mut embd, &embd_, workspace, queue_alloc)?;
178182

183+
let &[batch, size, _] = embd.shape() else {
184+
unreachable!()
185+
};
186+
179187
{
180188
let pos_embd = self.weights.pos_embd(queue);
189+
let pos = pos.broadcast(0, batch);
181190
self.add_rows(&mut embd, &pos_embd, &pos, workspace, queue_alloc)?
182191
}
183192

184-
let &[batch, size, _] = embd.shape() else {
185-
unreachable!()
186-
};
187193
let batch_split = vec![size; batch];
188194

189195
let np = batch * size;
@@ -281,6 +287,7 @@ where
281287

282288
let d0 = self.meta.d;
283289
let w = self.meta.mat(d, d0).map(|_| weights.resampler_wkv(queue));
290+
// (np d0) <- (np d) · (d d0)
284291
self.mat_mul(&mut v, &x, (w, None), workspace, queue_alloc)?;
285292

286293
let [w, b] = weights.resampler_ln_q(queue);
@@ -292,9 +299,12 @@ where
292299
let inplace = unsafe { v.map_slice_static() };
293300
self.layer_norm(&mut v, &inplace, ln_v, workspace, queue_alloc)?;
294301

295-
let (buf, workspace) = workspace.split_at_mut(*kv.get());
296-
let pos_embd = Tensor::new(dt, v.shape()).map(|_| buf);
297-
self.add(&mut k, &v, &pos_embd, workspace, queue_alloc)?;
302+
{
303+
let mut k = k.map_slice_mut().tile(0, &[batch, size]);
304+
let v = v.map_slice().tile(0, &[batch, size]);
305+
let pos = pos_resampler.broadcast(0, batch);
306+
self.add(&mut k, &v, &pos, workspace, queue_alloc)?
307+
}
298308

299309
let attn_w = self.meta.mat(d, d);
300310
let attn_b = self.meta.mat(d, 1);

0 commit comments

Comments
 (0)