11use crate :: { Operators , Weights } ;
22use clip:: { ClipArgs , ClipMeta , ClipStorage , ClipWorker , Image , Tensor , D_POS_EMBD } ;
3- use gguf:: { ggml_quants:: digit_layout:: types as ty, GGufModel } ;
3+ use gguf:: {
4+ ggml_quants:: { digit_layout:: types as ty, f16} ,
5+ GGufModel ,
6+ } ;
47use operators:: {
58 common_cpu:: { Cpu , ThisThread } ,
69 Blob ,
@@ -53,22 +56,24 @@ fn test_infer() {
5356 . launch (
5457 ClipArgs {
5558 raw : whole. to_nchw ( ) ,
56- pos : pos70 ( 1 , whole. shape ( ) , d_patch) . map_slice ( ) ,
59+ pos : pos70 ( whole. shape ( ) , d_patch) . map_slice ( ) ,
60+ pos_resampler : pos_resampler ( 3584 , whole. shape ( ) , d_patch) . map_slice ( ) ,
5761 } ,
5862 & mut [ ] ,
5963 & ThisThread ,
6064 )
6165 . unwrap ( ) ;
6266
6367 if let Some ( patches) = slices. patches_nchw ( ) {
64- let & [ n , 3 , h, w] = patches. shape ( ) else {
68+ let & [ _ , 3 , h, w] = patches. shape ( ) else {
6569 unreachable ! ( )
6670 } ;
6771 worker
6872 . launch (
6973 ClipArgs {
7074 raw : patches. map_slice ( ) ,
71- pos : pos70 ( n, [ w, h] , d_patch) . map_slice ( ) ,
75+ pos : pos70 ( [ w, h] , d_patch) . map_slice ( ) ,
76+ pos_resampler : pos_resampler ( 3584 , [ w, h] , d_patch) . map_slice ( ) ,
7277 } ,
7378 & mut [ ] ,
7479 & ThisThread ,
@@ -77,7 +82,7 @@ fn test_infer() {
7782 }
7883}
7984
80- fn pos70 ( n : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
85+ fn pos70 ( [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
8186 let w = w / d_patch;
8287 let h = h / d_patch;
8388
@@ -95,15 +100,15 @@ fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
95100 data[ i] = ( y * D_POS_EMBD + x) as _ ;
96101 }
97102
98- ans. broadcast ( 0 , n )
103+ ans
99104}
100105
101- fn pos_resampler ( d : usize , n : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
106+ fn pos_resampler ( d : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
102107 let w = w / d_patch;
103108 let h = h / d_patch;
104109
105- let mut ans = Tensor :: new ( ty:: F32 , & [ 1 , h * w, d] ) . map ( Blob :: new) ;
106- let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < f32 > ( ) } ) else {
110+ let mut ans = Tensor :: new ( ty:: F16 , & [ 1 , h * w, d] ) . map ( Blob :: new) ;
111+ let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < f16 > ( ) } ) else {
107112 panic ! ( )
108113 } ;
109114
@@ -118,15 +123,15 @@ fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tens
118123 let d = d / 4 ;
119124 for i in 0 ..d {
120125 let ( sin, cos) = cache[ c * d + i] ;
121- data[ 0 * d..] [ i] = sin;
122- data[ 1 * d..] [ i] = cos;
126+ data[ 0 * d..] [ i] = f16 :: from_f32 ( sin) ;
127+ data[ 1 * d..] [ i] = f16 :: from_f32 ( cos) ;
123128 let ( sin, cos) = cache[ r * d + i] ;
124- data[ 2 * d..] [ i] = sin;
125- data[ 3 * d..] [ i] = cos;
129+ data[ 2 * d..] [ i] = f16 :: from_f32 ( sin) ;
130+ data[ 3 * d..] [ i] = f16 :: from_f32 ( cos) ;
126131 }
127132 }
128133
129- ans. broadcast ( 0 , n )
134+ ans
130135}
131136
132137fn sin_cos_cache ( max_idx : usize , d : usize , theta : f32 ) -> Vec < ( f32 , f32 ) > {
0 commit comments