@@ -96,3 +96,163 @@ fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
9696
9797 ans
9898}
99+
100+ #[ cfg( test) ]
101+ mod test_pos_embd {
102+ use super :: * ;
103+
104+ mod c_like {
105+ use super :: * ;
106+ pub ( super ) fn pos_resampler ( n : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
107+ let d = 3584 ;
108+ let pos_w = w / d_patch;
109+ let pos_h = h / d_patch;
110+
111+ let mut ans = Tensor :: new ( ty:: F32 , & [ 1 , pos_w * pos_h, d] )
112+ . broadcast ( 0 , n)
113+ . map ( Blob :: new) ;
114+ let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < f32 > ( ) } ) else {
115+ panic ! ( )
116+ } ;
117+
118+ let pos_embed_t = get_2d_sincos_pos_embed ( d, ( pos_w, pos_h) ) ;
119+
120+ for i in 0 ..pos_w * pos_h {
121+ for j in 0 ..d {
122+ data[ i * d + j] = pos_embed_t[ i] [ j] ;
123+ }
124+ }
125+ ans
126+ }
127+
128+ fn get_2d_sincos_pos_embed ( embed_dim : usize , image_size : ( usize , usize ) ) -> Vec < Vec < f32 > > {
129+ let ( grid_h_size, grid_w_size) = image_size;
130+
131+ let mut grid_h: Vec < f32 > = ( 0 ..grid_h_size) . map ( |i| i as f32 ) . collect ( ) ;
132+ let mut grid_w: Vec < f32 > = ( 0 ..grid_w_size) . map ( |i| i as f32 ) . collect ( ) ;
133+
134+ let mut grid: Vec < Vec < f32 > > = vec ! [ vec![ 0.0 ; grid_w_size] ; grid_h_size] ;
135+ for h in 0 ..grid_h_size {
136+ for w in 0 ..grid_w_size {
137+ grid[ h] [ w] = grid_w[ w] ;
138+ }
139+ }
140+
141+ let mut grid_2d: Vec < Vec < Vec < f32 > > > = vec ! [ grid. clone( ) , grid. clone( ) ] ;
142+ for h in 0 ..grid_h_size {
143+ for w in 0 ..grid_w_size {
144+ grid_2d[ 0 ] [ h] [ w] = grid_h[ h] ;
145+ grid_2d[ 1 ] [ h] [ w] = grid_w[ w] ;
146+ }
147+ }
148+
149+ let pos_embed_3d = get_2d_sincos_pos_embed_from_grid ( embed_dim, grid_2d) ;
150+
151+ let ( H , W ) = image_size;
152+ let mut pos_embed_2d: Vec < Vec < f32 > > = vec ! [ vec![ 0.0 ; embed_dim] ; H * W ] ;
153+ for h in 0 ..H {
154+ for w in 0 ..W {
155+ pos_embed_2d[ w * H + h] = pos_embed_3d[ h] [ w] . clone ( ) ;
156+ }
157+ }
158+
159+ pos_embed_2d
160+ }
161+
162+ fn get_2d_sincos_pos_embed_from_grid (
163+ embed_dim : usize ,
164+ grid : Vec < Vec < Vec < f32 > > > ,
165+ ) -> Vec < Vec < Vec < f32 > > > {
166+ assert ! ( embed_dim % 2 == 0 ) ;
167+
168+ let emb_h = get_1d_sincos_pos_embed_from_grid_new ( embed_dim / 2 , grid[ 0 ] . clone ( ) ) ; // (H, W, D/2)
169+ let emb_w = get_1d_sincos_pos_embed_from_grid_new ( embed_dim / 2 , grid[ 1 ] . clone ( ) ) ; // (H, W, D/2)
170+
171+ let H = emb_h. len ( ) ;
172+ let W = emb_h[ 0 ] . len ( ) ;
173+ let mut emb: Vec < Vec < Vec < f32 > > > = vec ! [ vec![ vec![ 0.0 ; embed_dim] ; W ] ; H ] ;
174+
175+ for h in 0 ..H {
176+ for w in 0 ..W {
177+ for d in 0 ..( embed_dim / 2 ) {
178+ emb[ h] [ w] [ d] = emb_h[ h] [ w] [ d] ;
179+ emb[ h] [ w] [ d + embed_dim / 2 ] = emb_w[ h] [ w] [ d] ;
180+ }
181+ }
182+ }
183+
184+ emb
185+ }
186+
187+ fn get_1d_sincos_pos_embed_from_grid_new (
188+ embed_dim : usize ,
189+ pos : Vec < Vec < f32 > > ,
190+ ) -> Vec < Vec < Vec < f32 > > > {
191+ assert ! ( embed_dim % 2 == 0 ) ;
192+ let H = pos. len ( ) ;
193+ let W = pos[ 0 ] . len ( ) ;
194+
195+ let mut omega: Vec < f32 > = ( 0 ..embed_dim / 2 )
196+ . map ( |i| 1.0 / 10000.0f32 . powi ( i as i32 / ( embed_dim / 2 ) as i32 ) )
197+ . collect ( ) ;
198+
199+ let mut emb: Vec < Vec < Vec < f32 > > > = vec ! [ vec![ vec![ 0.0 ; embed_dim] ; W ] ; H ] ;
200+ for h in 0 ..H {
201+ for w in 0 ..W {
202+ for d in 0 ..( embed_dim / 2 ) {
203+ let out_value = pos[ h] [ w] * omega[ d] ;
204+ emb[ h] [ w] [ d] = out_value. sin ( ) ;
205+ emb[ h] [ w] [ d + embed_dim / 2 ] = out_value. cos ( ) ;
206+ }
207+ }
208+ }
209+
210+ emb
211+ }
212+ }
213+
214+ mod rust_style {
215+ use super :: * ;
216+ pub ( super ) fn pos_resampler ( n : usize , [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
217+ let d = 3584 ;
218+ assert ! ( d % 4 == 0 ) ;
219+
220+ let pos_w = w / d_patch;
221+ let pos_h = h / d_patch;
222+
223+ let mut ans = Tensor :: new ( ty:: F32 , & [ 1 , pos_w * pos_h, d] ) . map ( Blob :: new) ;
224+ let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < f32 > ( ) } ) else {
225+ panic ! ( )
226+ } ;
227+ set_2d_sincos_pos_embed ( data, d, ( pos_w, pos_h) ) ;
228+
229+ ans. broadcast ( 0 , n)
230+ }
231+
232+ fn set_2d_sincos_pos_embed ( data : & mut [ f32 ] , d : usize , ( h, w) : ( usize , usize ) ) {
233+ for r in 0 ..h {
234+ for c in 0 ..w {
235+ let data = & mut data[ ( c * h + r) * d..] [ ..d] ;
236+ let d = d / 4 ;
237+
238+ for i in 0 ..d {
239+ let ( sin, cos) = ( r as f32 ) . sin_cos ( ) ;
240+ data[ 0 * d..] [ i] = sin;
241+ data[ 1 * d..] [ i] = cos;
242+
243+ let ( sin, cos) = ( c as f32 ) . sin_cos ( ) ;
244+ data[ 2 * d..] [ i] = sin;
245+ data[ 3 * d..] [ i] = cos;
246+ }
247+ }
248+ }
249+ }
250+ }
251+
252+ #[ test]
253+ fn test_eq ( ) {
254+ let a = c_like:: pos_resampler ( 4 , [ 336 , 224 ] , 14 ) . take ( ) ;
255+ let b = rust_style:: pos_resampler ( 4 , [ 336 , 224 ] , 14 ) . take ( ) ;
256+ assert_eq ! ( & * a, & * b) ;
257+ }
258+ }
0 commit comments