Skip to content

Commit 9bfe2a5

Browse files
committed
feat(clip): 分片图片批量编码 (batch-encode sliced image patches)
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 2e5d211 commit 9bfe2a5

File tree

5 files changed

+67
-30
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ itertools = "0.13"
2828
build-script-cfg = "0.0"
2929

3030
ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "48d36c5" }
31-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "1b08473", default-features = false }
31+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "d73a53e", default-features = false }
3232

3333
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "6846d52" }
3434
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e2ec203" }

models/clip/common-cpu/src/test_infer.rs

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -56,19 +56,15 @@ fn test_infer() {
5656
)
5757
.unwrap();
5858

59-
let [x, y] = slices.grid();
60-
for i in 0..y {
61-
for j in 0..x {
62-
let patch = slices.patch(j, i);
63-
worker
64-
.launch(
65-
ClipArgs {
66-
raw: patch.to_nchw(),
67-
},
68-
&mut [],
69-
&ThisThread,
70-
)
71-
.unwrap();
72-
}
59+
if let Some(patches) = slices.patches_nchw() {
60+
worker
61+
.launch(
62+
ClipArgs {
63+
raw: patches.map_slice(),
64+
},
65+
&mut [],
66+
&ThisThread,
67+
)
68+
.unwrap();
7369
}
7470
}

models/clip/common/src/compute.rs

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,10 @@ use operators::{
33
conv::{self, Conv},
44
ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode,
55
};
6-
use std::ops::{Deref, DerefMut};
6+
use std::{
7+
ops::{Deref, DerefMut},
8+
time::Instant,
9+
};
710
use tensor::Tensor;
811

912
pub trait Operators {
@@ -60,6 +63,7 @@ where
6063
where
6164
QA: QueueAlloc<Hardware = Ops::Hardware>,
6265
{
66+
let time = Instant::now();
6367
let Args { raw } = args;
6468
let queue = queue_alloc.queue();
6569

@@ -74,7 +78,13 @@ where
7478
};
7579

7680
let mut embd = Tensor::new(dt_embd, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
77-
self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)
81+
self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)?;
82+
83+
if self.debug {
84+
println!("encode {n} x {h} x {w} image in {:?}", time.elapsed());
85+
}
86+
87+
Ok(())
7888
}
7989
}
8090

models/clip/common/src/image.rs

Lines changed: 27 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
use def::*;
1+
use common::{borrow, own, Contiguous};
2+
use def::*;
23
use gguf::ggml_quants::{
34
digit_layout::{types as ty, DigitLayout},
45
f16,
@@ -7,7 +8,7 @@ use image::ImageReader;
78
use itertools::izip;
89
use rayon::iter::{IntoParallelIterator, ParallelIterator};
910
use std::{iter::zip, ops::Deref, path::Path, slice::from_raw_parts_mut};
10-
use tensor::{Blob, Tensor};
11+
use tensor::{rearrange, Blob, Tensor};
1112

1213
#[repr(transparent)]
1314
pub struct Image<T>(Tensor<T>);
@@ -161,11 +162,7 @@ where
161162

162163
/// NHWC rgb Tensor -> NCHW value Tensor
163164
pub fn to_nchw(&self) -> Tensor<&[u8]> {
164-
self.0
165-
.destruct_array()
166-
.map(|t| &**t)
167-
.transpose(&[2, 0, 1])
168-
.tile(0, &[1, 3])
165+
rgb_to_chw(&self.0).tile(0, &[1, 3])
169166
}
170167
}
171168

@@ -198,6 +195,19 @@ impl ImageGrid {
198195
)
199196
}
200197

198+
pub fn patches_nchw(&self) -> Option<Tensor<Contiguous<Blob>>> {
199+
self.grid.as_ref().map(|data| {
200+
let xychw = rgb_to_chw(data);
201+
if let Some(nchw) = xychw.as_ref().merge(0..2) {
202+
nchw.map(|s| borrow(s))
203+
} else {
204+
let mut blob = Tensor::new(xychw.dt(), xychw.shape()).map(Blob::new);
205+
rearrange(&mut blob, &xychw);
206+
blob.merge(0..2).unwrap().map(own)
207+
}
208+
})
209+
}
210+
201211
/// [urgb] 转 [frgb]
202212
pub fn normalize(&self, dt: DigitLayout, mean: frgb96, std: frgb96) -> Self {
203213
let dt = match dt {
@@ -317,6 +327,16 @@ where
317327
ans
318328
}
319329

330+
fn rgb_to_chw<T>(data: &Tensor<T>) -> Tensor<&[u8]>
331+
where
332+
T: Deref<Target = [u8]>,
333+
{
334+
let ndim = data.shape().len();
335+
data.map_slice()
336+
.destruct_array()
337+
.transpose(&[ndim, ndim - 2, ndim - 1])
338+
}
339+
320340
#[test]
321341
fn test() {
322342
use std::time::Instant;

tensor/src/lib.rs

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -34,19 +34,30 @@ impl Tensor<usize> {
3434
/// access
3535
impl<T> Tensor<T> {
3636
/// 打开数组数据类型
37-
pub fn destruct_array(&self) -> Tensor<&T> {
37+
pub fn destruct_array(self) -> Self {
3838
use ggus::ggml_quants::digit_layout::LayoutContent::{Real, Unsigned};
3939
use std::iter::once;
4040

41-
let len = self.dt.group_size();
42-
let dt = match self.dt.decode() {
41+
let Self {
42+
dt,
43+
layout,
44+
physical,
45+
} = self;
46+
47+
let len = dt.group_size();
48+
let dt = match dt.decode() {
4349
Unsigned { width } if len > 1 => DigitLayout::unsigned(width as _, 1),
4450
Real { exponent, mantissa } if len > 1 => {
4551
DigitLayout::real(exponent as _, mantissa as _, 1)
4652
}
47-
_ => return self.as_ref(),
53+
_ => {
54+
return Self {
55+
dt,
56+
layout,
57+
physical,
58+
}
59+
}
4860
};
49-
let layout = &self.layout;
5061
let shape = layout
5162
.shape()
5263
.iter()
@@ -63,7 +74,7 @@ impl<T> Tensor<T> {
6374
Tensor {
6475
dt,
6576
layout: ArrayLayout::new(&shape, &strides, offset),
66-
physical: &self.physical,
77+
physical,
6778
}
6879
}
6980

0 commit comments

Comments (0)