@@ -42,7 +42,7 @@ def batch_patchify(
         pad: bool = True,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     B, C, H, W = x.shape
-    ph, pw = to_2tuple(patch_size)
+    ph, pw = patch_size
 
     # Ensure the image is divisible by patch size
     if pad and (H % ph != 0 or W % pw != 0):
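For orientation, here is a minimal sketch of the patchification `batch_patchify` performs under the new tuple-only contract (`patch_size` must already be a `(ph, pw)` pair). This is an illustrative reshape-based version, not the library's exact implementation:

```python
import torch
import torch.nn.functional as F
from typing import Tuple


def patchify_sketch(x: torch.Tensor, patch_size: Tuple[int, int], pad: bool = True):
    """Illustrative only: [B, C, H, W] -> ([B, N, ph*pw*C], (grid_h, grid_w))."""
    B, C, H, W = x.shape
    ph, pw = patch_size  # caller must now pass an explicit (ph, pw) tuple
    if pad and (H % ph != 0 or W % pw != 0):
        # Pad bottom/right so both spatial dims divide evenly by the patch size
        x = F.pad(x, (0, -W % pw, 0, -H % ph))
        B, C, H, W = x.shape
    gh, gw = H // ph, W // pw
    patches = x.reshape(B, C, gh, ph, gw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, gh * gw, ph * pw * C)
    return patches, (gh, gw)
```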
@@ -202,21 +202,20 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
         else:
             return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x, patch_coord=None, patch_valid=None):
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
         """Forward pass for combined embedding
 
         Args:
             x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
             patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
-            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
 
         Returns:
             Embedded tensor with position encoding and class/register tokens applied
             If patch_type is provided, also returns attention mask
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
-        grid_size: Optional[Tuple[int, int]] = None
+        grid_size: Optional[List[int]] = None
 
         B = x.shape[0]
         if self.is_linear:
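To make the docstring concrete, these are the two input layouts the new signature accepts (shapes are illustrative, assuming 16x16 patches on 3-channel input; `embeds` stands for an instance of this module, as in the `self.embeds` call further down):

```python
import torch

# Mode 1: standard 2D image input
img_input = torch.randn(2, 3, 224, 224)                        # [B, C, H, W]

# Mode 2: pre-patchified NaFlex input plus per-patch (y, x) grid coordinates
patches = torch.randn(2, 196, 16 * 16 * 3)                     # [B, N, P*P*C]
yy, xx = torch.meshgrid(torch.arange(14), torch.arange(14), indexing='ij')
patch_coord = torch.stack([yy, xx], dim=-1).reshape(1, -1, 2).expand(2, -1, -1)  # [B, N, 2]

# tokens = embeds(img_input)
# tokens = embeds(patches, patch_coord=patch_coord)
```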
@@ -227,7 +226,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
-                naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
+                naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
             else:
                 _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
@@ -257,6 +256,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
             if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
+                assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
@@ -287,15 +287,19 @@ def _apply_learned_naflex_pos_embed(
         orig_h, orig_w = self.pos_embed.shape[1:3]
 
         # Determine unique grid sizes
-        size_to_indices = {}
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, (h, w) in enumerate(naflex_grid_sizes):
-            if not (h, w) in size_to_indices:
-                size_to_indices[(h, w)] = [bi]
+            #k = h << 16 | w  # FIXME can get jit compat with this
+            k = (h, w)
+            if not k in size_to_indices:
+                size_to_indices[k] = [bi]
             else:
-                size_to_indices[(h, w)].append(bi)
+                size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        for (h, w), batch_indices in size_to_indices.items():
+        for k, batch_indices in size_to_indices.items():
+            h, w = k
+            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
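The FIXME comments point at a TorchScript-friendly variant: pack each `(h, w)` pair into a single `int` key so the dict is typed `Dict[int, List[int]]`. A rough sketch of that idea (assuming grid sides fit in 16 bits):

```python
from typing import Dict, List

naflex_grid_sizes = [(12, 16), (16, 12), (12, 16)]   # example per-sample grid sizes

size_to_indices: Dict[int, List[int]] = {}
for bi, (h, w) in enumerate(naflex_grid_sizes):
    k = h << 16 | w                                  # pack (h, w) into one int
    size_to_indices.setdefault(k, []).append(bi)

for k, batch_indices in size_to_indices.items():
    h, w = k >> 16, k & 0xFFFF                       # unpack back to (h, w)
    # ...interpolate the pos embed once for this (h, w), then scatter to batch_indices
    print((h, w), batch_indices)
```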
@@ -315,7 +319,7 @@ def _apply_learned_naflex_pos_embed(
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
-            grid_size: Tuple[int, int],
+            grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] != orig_h or grid_size[1] != orig_w:
@@ -340,7 +344,7 @@ def _apply_learned_pos_embed(
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -357,7 +361,7 @@ def create_attention_mask(
         Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
         or None if patch_type is None
     """
-    patch_valid = patch_valid.bool()
+    patch_valid = patch_valid.to(torch.bool)
     B = patch_valid.shape[0]
 
     if num_prefix_tokens > 0:
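For intuition, a sketch of the masking pattern a helper like this produces (not the exact implementation): prefix tokens are always treated as valid, pairs of valid positions get 0, and any pair involving a padding token gets the dtype minimum so it vanishes in the softmax:

```python
import torch


def attention_mask_sketch(
        patch_valid: torch.Tensor,            # [B, N] bool/int, 1 = real patch
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    patch_valid = patch_valid.to(torch.bool)
    B = patch_valid.shape[0]
    if num_prefix_tokens > 0:
        # Class/register tokens are always attendable
        prefix_valid = patch_valid.new_ones(B, num_prefix_tokens)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    valid_2d = patch_valid.unsqueeze(1) & patch_valid.unsqueeze(2)     # [B, L, L]
    mask = torch.zeros_like(valid_2d, dtype=dtype)
    mask.masked_fill_(~valid_2d, torch.finfo(dtype).min)
    return mask
```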
@@ -373,7 +377,7 @@ def create_attention_mask(
 
 @register_notrace_function
 def create_attention_mask2(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
@@ -411,7 +415,7 @@ def create_attention_mask2(
 
 @register_notrace_function
 def create_pool_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     patch_valid = patch_valid.bool()
@@ -773,8 +777,16 @@ def forward_features(
             patch_valid: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if attn_mask is None and patch_valid is not None:
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens,
+                dtype=x.dtype
+            )
+
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(x, patch_coord=patch_coord)
 
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None:
@@ -827,7 +839,7 @@ def _pool(
 
         # For max pooling with mask
         masked_x = x.clone()
-        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
         masked_max = masked_x.max(dim=1)[0]
 
         # Combine average and max
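A sketch of the masked "avgmax" pooling these lines belong to (assumed shapes: `x` is [B, N, C], `patch_valid` is [B, N]). The switch to a finite `-1e4` fill presumably avoids the overflow/NaN issues that a `finfo(...).min` sentinel can cause in low-precision dtypes:

```python
import torch


def masked_avgmax_pool(x: torch.Tensor, patch_valid: torch.Tensor) -> torch.Tensor:
    patch_valid = patch_valid.bool()

    # Average over valid patches only
    valid = patch_valid.unsqueeze(-1).to(x.dtype)                  # [B, N, 1]
    masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)

    # Max over valid patches: fill padding with a large negative (finite, fp16-safe) value
    masked_x = x.clone()
    masked_x[~patch_valid] = -1e4
    masked_max = masked_x.max(dim=1)[0]

    # Combine average and max
    return 0.5 * (masked_avg + masked_max)
```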
@@ -864,27 +876,23 @@ def forward(
         Returns:
             Model output tensor
         """
-        # Handle dictionary input from NaFlex collator
-        if isinstance(x, dict):
-            assert patch_coord is None
-            assert patch_valid is None
-            # Extract the required components from the dictionary
+        if isinstance(x, torch.Tensor):
+            patches = x
+        else:
+            # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
             patches = x['patches']
 
-            if False:
-                # DEBUG, reconstruct patches
-                for i in range(len(patches)):
-                    patch = patches[i][patch_valid[i]]
-                    h = (patch_coord[i, :, 0].max() + 1).item()
-                    w = (patch_coord[i, :, 1].max() + 1).item()
-                    patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
-                    patch = patch.reshape(3, h * 16, w * 16)
-                    from torchvision.utils import save_image
-                    save_image(patch, f'patch_{i}.jpg', normalize=True)
-        else:
-            patches = x
+        # DEBUG, reconstruct patches
+        # for i in range(len(patches)):
+        #     patch = patches[i][patch_valid[i]]
+        #     h = (patch_coord[i, :, 0].max() + 1).item()
+        #     w = (patch_coord[i, :, 1].max() + 1).item()
+        #     patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
+        #     patch = patch.reshape(3, h*16, w*16)
+        #     from torchvision.utils import save_image
+        #     save_image(patch, f'patch_{i}.jpg', normalize=True)
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
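Putting the two input paths together, a hypothetical call pattern (key names taken from the code above; `model` and the shapes are assumptions for illustration):

```python
import torch

# Plain tensor path: a regular image batch
# out = model(torch.randn(2, 3, 256, 256))

# NaFlex collator path: dict of pre-patchified inputs
batch = {
    'patches': torch.randn(2, 256, 16 * 16 * 3),                  # [B, N, P*P*C]
    'patch_coord': torch.zeros(2, 256, 2, dtype=torch.long),      # [B, N, 2] (y, x) grid positions
    'patch_valid': torch.ones(2, 256, dtype=torch.bool),          # True = real patch, False = padding
}
# out = model(batch)
```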