1- use def:: * ;
1+ use common:: { borrow, own, Contiguous } ;
2+ use def:: * ;
23use gguf:: ggml_quants:: {
34 digit_layout:: { types as ty, DigitLayout } ,
45 f16,
@@ -7,7 +8,7 @@ use image::ImageReader;
78use itertools:: izip;
89use rayon:: iter:: { IntoParallelIterator , ParallelIterator } ;
910use std:: { iter:: zip, ops:: Deref , path:: Path , slice:: from_raw_parts_mut} ;
10- use tensor:: { Blob , Tensor } ;
11+ use tensor:: { rearrange , Blob , Tensor } ;
1112
1213#[ repr( transparent) ]
1314pub struct Image < T > ( Tensor < T > ) ;
@@ -161,11 +162,7 @@ where
161162
162163 /// NHWC rgb Tensor -> NCHW value Tensor
163164 pub fn to_nchw ( & self ) -> Tensor < & [ u8 ] > {
164- self . 0
165- . destruct_array ( )
166- . map ( |t| & * * t)
167- . transpose ( & [ 2 , 0 , 1 ] )
168- . tile ( 0 , & [ 1 , 3 ] )
165+ rgb_to_chw ( & self . 0 ) . tile ( 0 , & [ 1 , 3 ] )
169166 }
170167}
171168
@@ -198,6 +195,19 @@ impl ImageGrid {
198195 )
199196 }
200197
198+ pub fn patches_nchw ( & self ) -> Option < Tensor < Contiguous < Blob > > > {
199+ self . grid . as_ref ( ) . map ( |data| {
200+ let xychw = rgb_to_chw ( data) ;
201+ if let Some ( nchw) = xychw. as_ref ( ) . merge ( 0 ..2 ) {
202+ nchw. map ( |s| borrow ( s) )
203+ } else {
204+ let mut blob = Tensor :: new ( xychw. dt ( ) , xychw. shape ( ) ) . map ( Blob :: new) ;
205+ rearrange ( & mut blob, & xychw) ;
206+ blob. merge ( 0 ..2 ) . unwrap ( ) . map ( own)
207+ }
208+ } )
209+ }
210+
201211 /// [urgb] 转 [frgb]
202212 pub fn normalize ( & self , dt : DigitLayout , mean : frgb96 , std : frgb96 ) -> Self {
203213 let dt = match dt {
@@ -317,6 +327,16 @@ where
317327 ans
318328}
319329
330+ fn rgb_to_chw < T > ( data : & Tensor < T > ) -> Tensor < & [ u8 ] >
331+ where
332+ T : Deref < Target = [ u8 ] > ,
333+ {
334+ let ndim = data. shape ( ) . len ( ) ;
335+ data. map_slice ( )
336+ . destruct_array ( )
337+ . transpose ( & [ ndim, ndim - 2 , ndim - 1 ] )
338+ }
339+
320340#[ test]
321341fn test ( ) {
322342 use std:: time:: Instant ;
0 commit comments