@@ -78,21 +78,63 @@ fn test_infer() {
7878}
7979
8080fn pos70 ( n : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
81- let pos_w = w / d_patch;
82- let pos_h = h / d_patch;
81+ let w = w / d_patch;
82+ let h = h / d_patch;
8383
84- let mut ans = Tensor :: new ( ty:: U32 , & [ 1 , pos_w * pos_h] )
85- . broadcast ( 0 , n)
86- . map ( Blob :: new) ;
84+ let mut ans = Tensor :: new ( ty:: U32 , & [ 1 , h * w] ) . map ( Blob :: new) ;
8785 let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < u32 > ( ) } ) else {
8886 panic ! ( )
8987 } ;
9088
91- for i in 0 ..pos_h * pos_w {
92- let y = ( i / pos_w) * D_POS_EMBD / pos_h;
93- let x = ( i % pos_w) * D_POS_EMBD / pos_w;
89+ for i in 0 ..h * w {
90+ let r = i / w;
91+ let c = i % w;
92+
93+ let y = r * D_POS_EMBD / h;
94+ let x = c * D_POS_EMBD / w;
9495 data[ i] = ( y * D_POS_EMBD + x) as _ ;
9596 }
9697
97- ans
98+ ans. broadcast ( 0 , n)
99+ }
100+
101+ fn pos_resampler ( d : usize , n : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
102+ let w = w / d_patch;
103+ let h = h / d_patch;
104+
105+ let mut ans = Tensor :: new ( ty:: F32 , & [ 1 , h * w, d] ) . map ( Blob :: new) ;
106+ let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < f32 > ( ) } ) else {
107+ panic ! ( )
108+ } ;
109+
110+ assert ! ( d % 4 == 0 ) ;
111+ let cache = sin_cos_cache ( w. max ( h) , d / 4 , 1e4 ) ;
112+
113+ for i in 0 ..h * w {
114+ let r = i / w;
115+ let c = i % w;
116+
117+ let data = & mut data[ i * d..] [ ..d] ;
118+ let d = d / 4 ;
119+ for i in 0 ..d {
120+ let ( sin, cos) = cache[ c * d + i] ;
121+ data[ 0 * d..] [ i] = sin;
122+ data[ 1 * d..] [ i] = cos;
123+ let ( sin, cos) = cache[ r * d + i] ;
124+ data[ 2 * d..] [ i] = sin;
125+ data[ 3 * d..] [ i] = cos;
126+ }
127+ }
128+
129+ ans. broadcast ( 0 , n)
130+ }
131+
132+ fn sin_cos_cache ( max_idx : usize , d : usize , theta : f32 ) -> Vec < ( f32 , f32 ) > {
133+ ( 0 ..max_idx * d)
134+ . map ( |i| {
135+ let a = ( i / d) as f32 ;
136+ let b = ( i % d) as f32 ;
137+ ( a * theta. powf ( -( b / d as f32 ) ) ) . sin_cos ( )
138+ } )
139+ . collect ( )
98140}
0 commit comments