
Commit d80b378

fix(clip): align precision with PyTorch

Signed-off-by: YdrMaster <[email protected]>
1 parent 88740f9 commit d80b378

4 files changed: +62 -43 lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"

-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "61789f7", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "0bd4107", default-features = false }

 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/clip/common-cpu/src/lib.rs

Lines changed: 9 additions & 0 deletions

@@ -165,6 +165,15 @@ impl WeightLoader for Weights<'_> {
         }
     }

+    fn resampler_ln_post<'a>(
+        &'a self,
+        _queue: &'a QueueOf<Self::Hardware>,
+    ) -> [Self::Memory<'a>; 2] {
+        match &self.0.projector {
+            ProjectorStroage::Resampler(storage) => storage.ln_post,
+        }
+    }
+
     fn resampler_proj<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a> {
         match &self.0.projector {
             ProjectorStroage::Resampler(storage) => storage.proj,

models/clip/common/src/compute.rs

Lines changed: 34 additions & 27 deletions

@@ -72,6 +72,8 @@ pub trait WeightLoader {
     fn resampler_attn_k<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2];
     fn resampler_attn_v<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2];
     fn resampler_attn_o<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2];
+    fn resampler_ln_post<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>)
+        -> [Self::Memory<'a>; 2];
     fn resampler_proj<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a>;
 }

@@ -268,7 +270,7 @@ where

         let weights = &self.weights.weights;
         let q0 = Tensor::new(dt, &[dq, d]).map(|_| weights.resampler_q(queue));
-        let ln_qkv = Tensor::new(dt_norm, &[d]);
+        let ln = Tensor::new(dt_norm, &[d]);

         let q = Tensor::new(dt, q0.shape());
         let kv = Tensor::new(dt, &[np, d]);

@@ -285,11 +287,11 @@ where
         self.mat_mul(&mut v, &x, (w, None), workspace, queue_alloc)?;

         let [w, b] = weights.resampler_ln_q(queue);
-        let ln_q = [ln_qkv.clone().map(|_| w), ln_qkv.clone().map(|_| b)];
+        let ln_q = [ln.clone().map(|_| w), ln.clone().map(|_| b)];
         self.layer_norm(&mut q, &q0, ln_q, workspace, queue_alloc)?;

         let [w, b] = weights.resampler_ln_kv(queue);
-        let ln_v = [ln_qkv.clone().map(|_| w), ln_qkv.clone().map(|_| b)];
+        let ln_v = [ln.clone().map(|_| w), ln.clone().map(|_| b)];
         let inplace = unsafe { v.map_slice_static() };
         self.layer_norm(&mut v, &inplace, ln_v, workspace, queue_alloc)?;

@@ -315,34 +317,34 @@ where
         let [w, b] = weights.resampler_attn_o(queue);
         let attn_o = (attn_w.clone().map(|_| w), Some(attn_b.clone().map(|_| b)));

-        let q_ = Tensor::new(dt, &[batch, dq, d]);
-        let k_ = Tensor::new(dt, &[np, d]);
-        let v_ = Tensor::new(dt, &[np, d]);
-        let o_ = Tensor::new(dt, &[batch * dq, d]);
+        let qo = Tensor::new(dt, &[batch * dq, d]);

-        let (buf, workspace) = workspace.split_at_mut(*q_.get());
-        let mut q_ = q_.map(|_| buf);
+        let (buf, workspace) = workspace.split_at_mut(*qo.get());
+        let mut q_ = qo.clone().map(|_| buf);
         {
-            let mut q_ = q_.map_slice_mut().index(0, 0);
-            self.mat_mul(&mut q_, &q, attn_q, workspace, queue_alloc)?
-        }
-        if batch > 1 {
-            split!(q_ => q0, q1; [1, batch - 1] @ 0);
-            let q0 = q0.broadcast(0, batch - 1);
-            let mut q1 = q1;
-            self.rearrange(&mut q1, &q0, workspace, queue_alloc)?
+            let mut q_ = q_.map_slice_mut().tile(0, &[batch, dq]);
+            {
+                let mut q_ = q_.map_slice_mut().index(0, 0);
+                self.mat_mul(&mut q_, &q, attn_q, workspace, queue_alloc)?
+            }
+            if batch > 1 {
+                split!(q_ => q0, q1; [1, batch - 1] @ 0);
+                let q0 = q0.broadcast(0, batch - 1);
+                let mut q1 = q1;
+                self.rearrange(&mut q1, &q0, workspace, queue_alloc)?
+            }
         }
-        let mut q_ = q_.merge(0..2).unwrap();
+        {
+            let kv = Tensor::new(dt, &[np, d]);

-        let (buf, workspace) = workspace.split_at_mut(*k_.get());
-        let mut k_ = k_.map(|_| buf);
-        self.mat_mul(&mut k_, &k, attn_k, workspace, queue_alloc)?;
+            let (buf, workspace) = workspace.split_at_mut(*kv.get());
+            let mut k_ = kv.clone().map(|_| buf);
+            self.mat_mul(&mut k_, &k, attn_k, workspace, queue_alloc)?;

-        let (buf, workspace) = workspace.split_at_mut(*v_.get());
-        let mut v_ = v_.map(|_| buf);
-        self.mat_mul(&mut v_, &v, attn_v, workspace, queue_alloc)?;
+            let (buf, workspace) = workspace.split_at_mut(*kv.get());
+            let mut v_ = kv.map(|_| buf);
+            self.mat_mul(&mut v_, &v, attn_v, workspace, queue_alloc)?;

-        {
             let nh_dh = &[d / dh, dh];
             let q = q_.map_slice_mut().tile(1, nh_dh).transpose(&[1, 0]);
             let k = k_.tile(1, nh_dh).transpose(&[1, 0]);

@@ -361,10 +363,15 @@ where
         }
         let o = q_;

-        let (buf, workspace) = workspace.split_at_mut(*o_.get());
-        let mut o_ = o_.map(|_| buf);
+        let (buf, workspace) = workspace.split_at_mut(*qo.get());
+        let mut o_ = qo.map(|_| buf);
         self.mat_mul(&mut o_, &o, attn_o, workspace, queue_alloc)?;

+        let [w, b] = weights.resampler_ln_post(queue);
+        let ln_post = [ln.clone().map(|_| w), ln.clone().map(|_| b)];
+        let inplace = unsafe { o_.map_slice_static() };
+        self.layer_norm(&mut o_, &inplace, ln_post, workspace, queue_alloc)?;
+
         let mut out = o;
         let w = attn_w.map(|_| weights.resampler_proj(queue));
         self.mat_mul(&mut out, &o_, (w, None), workspace, queue_alloc)?
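To summarize the compute.rs change: the attention output is no longer fed straight into the final projection. Using the weight names from the hunks above (W_o and b_o from resampler_attn_o, W_proj from resampler_proj, and the new resampler.ln_post.{weight, bias} pair), the resampler output path is now, schematically,

    out = LN_post(Attn(q, k, v) · W_o + b_o) · W_proj

whereas before this commit the LN_post step was missing, which appears to be the precision mismatch against the PyTorch reference that the commit title refers to.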

models/clip/common/src/projector/resampler.rs

Lines changed: 18 additions & 15 deletions

@@ -35,28 +35,31 @@ pub struct Storage<T> {
     pub attn_k: [T; 2],
     pub attn_v: [T; 2],
     pub attn_o: [T; 2],
+    pub ln_post: [T; 2],
     pub proj: T,
 }

 impl<'a> Storage<&'a [u8]> {
     #[rustfmt::skip]
     pub fn from_gguf(gguf: &GGufModel<'a>) -> Self {
         Self {
-            wkv   : tensor![gguf => "resampler.kv.weight"      ].data ,
-            q     : tensor![gguf => "resampler.query"          ].data ,
-            ln_q  : [tensor![gguf => "resampler.ln_q.weight"   ].data ,
-                     tensor![gguf => "resampler.ln_q.bias"     ].data],
-            ln_kv : [tensor![gguf => "resampler.ln_kv.weight"  ].data ,
-                     tensor![gguf => "resampler.ln_kv.bias"    ].data],
-            attn_q: [tensor![gguf => "resampler.attn.q.weight" ].data ,
-                     tensor![gguf => "resampler.attn.q.bias"   ].data],
-            attn_k: [tensor![gguf => "resampler.attn.k.weight" ].data ,
-                     tensor![gguf => "resampler.attn.k.bias"   ].data],
-            attn_v: [tensor![gguf => "resampler.attn.v.weight" ].data ,
-                     tensor![gguf => "resampler.attn.v.bias"   ].data],
-            attn_o: [tensor![gguf => "resampler.attn.out.weight"].data ,
-                     tensor![gguf => "resampler.attn.out.bias" ].data],
-            proj  : tensor![gguf => "resampler.proj.weight"    ].data ,
+            wkv    : tensor![gguf => "resampler.kv.weight"      ].data ,
+            q      : tensor![gguf => "resampler.query"          ].data ,
+            ln_q   : [tensor![gguf => "resampler.ln_q.weight"   ].data ,
+                      tensor![gguf => "resampler.ln_q.bias"     ].data],
+            ln_kv  : [tensor![gguf => "resampler.ln_kv.weight"  ].data ,
+                      tensor![gguf => "resampler.ln_kv.bias"    ].data],
+            attn_q : [tensor![gguf => "resampler.attn.q.weight" ].data ,
+                      tensor![gguf => "resampler.attn.q.bias"   ].data],
+            attn_k : [tensor![gguf => "resampler.attn.k.weight" ].data ,
+                      tensor![gguf => "resampler.attn.k.bias"   ].data],
+            attn_v : [tensor![gguf => "resampler.attn.v.weight" ].data ,
+                      tensor![gguf => "resampler.attn.v.bias"   ].data],
+            attn_o : [tensor![gguf => "resampler.attn.out.weight"].data ,
+                      tensor![gguf => "resampler.attn.out.bias" ].data],
+            ln_post: [tensor![gguf => "resampler.ln_post.weight"].data ,
+                      tensor![gguf => "resampler.ln_post.bias"  ].data],
+            proj   : tensor![gguf => "resampler.proj.weight"    ].data ,
         }
     }
 }
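For checking the "align precision with PyTorch" goal locally, the newly loaded ln_post weight and bias can be compared against torch.nn.LayerNorm applied to the same input row. The sketch below is illustrative only and not code from this repository: layer_norm_ref is a made-up name, and the epsilon is assumed to be PyTorch's default of 1e-5, since the diff does not show which value the underlying operators-rs kernel uses.

// Reference layer norm over the last dimension with elementwise affine
// parameters, mirroring torch.nn.LayerNorm semantics (assumed eps = 1e-5).
fn layer_norm_ref(x: &[f32], w: &[f32], b: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let rstd = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(w)
        .zip(b)
        .map(|((x, w), b)| (x - mean) * rstd * w + b)
        .collect()
}

fn main() {
    // Toy d = 4 row, normalized the way ln_post would normalize one token.
    let x = [1.0f32, 2.0, 3.0, 4.0];
    let (w, b) = ([1.0f32; 4], [0.0f32; 4]);
    println!("{:?}", layer_norm_ref(&x, &w, &b, 1e-5));
}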

0 commit comments