@@ -297,131 +297,109 @@ def _apply_learned_naflex_pos_embed(
             size_to_indices[k].append(bi)

         # Handle each batch element separately with its own grid size
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
-                pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
+                pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
-                pos_embed_resized = F.interpolate(
-                    self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                pos_embed_flat = F.interpolate(
+                    pos_embed_nchw,
                     size=(h, w),
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                )
-                pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
+                ).flatten(2).transpose(1, 2)
 
-            seq_len = min(x.shape[1], pos_embed_flat.shape[0])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
+            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+            x[batch_indices, :seq_len] += pos_embed_flat[:, :seq_len]
 
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
             grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
-        if grid_size[0] != orig_h or grid_size[1] != orig_w:
+        if grid_size[0] == orig_h and grid_size[1] == orig_w:
+            # No resize needed, just flatten
+            pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+        else:
             # Resize if needed - directly using F.interpolate
-            pos_embed = F.interpolate(
+            pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
-            )
-            # Convert back and flatten
-            pos_embed = pos_embed.permute(0, 2, 3, 1)
-            pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
-
-        else:
-            # No resize needed, just flatten
-            pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+            ).flatten(2).transpose(1, 2)
 
-        x.add_(pos_embed)
+        x.add_(pos_embed_flat)
 
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: torch.Tensor,
-        num_prefix_tokens: int = 0,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    """Create attention mask from patch type information.
-
-    Used for NaFlex mode to handle variable token counts and padding tokens.
-
-    Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
-        num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        dtype: Dtype of the attention mask
-
-    Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
-    """
-    patch_valid = patch_valid.to(torch.bool)
-    B = patch_valid.shape[0]
-
-    if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
-        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-
-    mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
-
-
-@register_notrace_function
-def create_attention_mask2(
         patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
+        symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> Optional[torch.Tensor]:
-    """Create expanded attention mask from patch validity info.
+) -> torch.Tensor:
+    """Creates an attention mask from patch validity information.
+
+    Supports two modes controlled by `symmetric`:
+    1. `symmetric=True` (default): Creates a symmetric mask of shape
+       [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
+       both token i and token j are valid. Suitable for standard self-attention.
+    2. `symmetric=False`: Creates a potentially non-square mask of shape
+       [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
+       the key/value token k is valid. Query token validity is not checked
+       in the mask itself. Useful for cross-attention or specific self-attention
+       implementations where `q_len` can be specified.
 
     Used for NaFlex mode to handle variable token counts and padding tokens.
 
     Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
+        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
         num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        q_len: Length override for query sequence
-        dtype: Dtype of the attention mask
+            to prepend, which are always considered valid.
+        symmetric: If True, create a symmetric mask.
+            If False, create an expanded mask based only on key/value validity.
+        q_len: Query sequence length override. Only used when `symmetric` is False.
+            Defaults to the key/value sequence length (`kv_len`) if None.
+        dtype: Dtype of the output attention mask (e.g., torch.float32).
 
     Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
+        Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
+        Shape is [B, 1, seq_len, seq_len] if symmetric=True,
+        or [B, 1, q_len, kv_len] if symmetric=False.
     """
-    patch_valid = patch_valid.bool()
-    B, kv_len = patch_valid.shape
+    patch_valid = patch_valid.bool()  # Ensure boolean type
+    B, N = patch_valid.shape
+    kv_len = N  # Initial key/value length is the number of patches
 
+    # Prepend prefix tokens if any
     if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
+        # Create prefix validity tensor on the same device/dtype base as patch_valid
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
+        # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-        kv_len = patch_valid.shape[1]
-
-    q_len = q_len if q_len is not None else kv_len
-
-    mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len).to(dtype)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
+        kv_len += num_prefix_tokens  # Update total key/value sequence length
 
+    if symmetric:
+        # Symmetric mask is True where BOTH query and key are valid
+        mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
+        mask_bool = mask_bool.unsqueeze(1)  # Add head dimension: [B, 1, seq_len, seq_len]
+    else:
+        # Expanded mask
+        q_len = q_len or kv_len
+        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
 
-@register_notrace_function
-def create_pool_mask(
-        patch_valid: torch.Tensor,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    patch_valid = patch_valid.bool()
-    mask_bool = patch_valid[:, None, None, :]
+    # Create the float mask and apply masking using additive mask convention
     mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
+    # Fill with negative infinity where mask_bool is False (masked positions)
+    mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
 
     return mask_float
 
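As a reading aid for the position-embedding changes above, here is a minimal standalone sketch of the resize-and-flatten step, not the module itself: it assumes a learned grid stored as [1, H, W, C], uses 'bicubic' in place of self.pos_embed_interp_mode, and the tensor sizes are made up.

import torch
import torch.nn.functional as F

# Hypothetical learned position embedding grid, [1, H, W, C] as in the module above
pos_embed = torch.randn(1, 16, 16, 384)
new_h, new_w = 12, 20                               # target grid for a non-square NaFlex image

pos_embed_nchw = pos_embed.permute(0, 3, 1, 2)      # [1, C, H, W] for F.interpolate
pos_embed_flat = F.interpolate(
    pos_embed_nchw,
    size=(new_h, new_w),
    mode='bicubic',                                 # stand-in for self.pos_embed_interp_mode
    align_corners=False,
    antialias=True,
).flatten(2).transpose(1, 2)                        # [1, new_h * new_w, C]

assert pos_embed_flat.shape == (1, new_h * new_w, 384)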
@@ -809,7 +787,12 @@ def _pool(
     ) -> torch.Tensor:
         if self.attn_pool is not None:
             # For attention pooling, we need to pass the mask for NaFlex models
-            attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
+            attn_mask = create_attention_mask(
+                patch_valid,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
             return x
 
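For orientation, a small usage sketch of the merged create_attention_mask, which now also covers the old create_pool_mask case; it assumes the function as defined in the hunk above is in scope, and the shapes in the comments follow from that definition.

import torch
# Assumes create_attention_mask() from the hunk above is importable / in scope.

patch_valid = torch.tensor([
    [True, True, False, False],   # sample 0: 2 valid patches, 2 padding
    [True, True, True,  True ],   # sample 1: all 4 patches valid
])                                # [B=2, N=4]

# Symmetric mask for self-attention over 1 prefix token + 4 patch tokens
self_attn_mask = create_attention_mask(patch_valid, num_prefix_tokens=1)
# self_attn_mask.shape == (2, 1, 5, 5); 0 where both tokens are valid, finfo.min otherwise

# Non-symmetric mask as used by _pool() above: one pooled query over all patch tokens
pool_mask = create_attention_mask(patch_valid, symmetric=False, q_len=1)
# pool_mask.shape == (2, 1, 1, 4); padding keys get torch.finfo(torch.float32).min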
@@ -839,7 +822,7 @@ def _pool(
 
         # For max pooling with mask
         masked_x = x.clone()
-        masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
         masked_max = masked_x.max(dim=1)[0]
 
         # Combine average and max
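The surrounding pooling code is only partially visible in this hunk; as a rough standalone sketch (with hypothetical tensor names) of the masked average plus masked max pattern it appears to implement:

import torch

x = torch.randn(2, 4, 8)                                   # [B, N, C] patch tokens
patch_valid = torch.tensor([[True, True, False, False],
                            [True, True, True,  True ]])   # [B, N]

valid = patch_valid.unsqueeze(-1).to(x.dtype)              # [B, N, 1]
masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)

masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min   # padding can never win the max
masked_max = masked_x.max(dim=1)[0]

pooled = 0.5 * (masked_avg + masked_max)                   # one plausible 'avgmax' combination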
@@ -876,9 +859,7 @@ def forward(
         Returns:
             Model output tensor
         """
-        if isinstance(x, torch.Tensor):
-            patches = x
-        else:
+        if isinstance(x, Dict):
             # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
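For context, the dictionary branch consumes NaFlex-collator style input. A rough sketch of the expected structure follows; only the 'patch_coord' and 'patch_valid' keys are confirmed by this diff, the 'patches' key and the shapes are assumptions.

import torch

B, N, P, C_in = 2, 196, 16, 3      # batch, max patches per sample, patch size, channels
naflex_input = {
    'patches': torch.randn(B, N, P * P * C_in),             # flattened patch pixels (assumed key/shape)
    'patch_coord': torch.zeros(B, N, 2, dtype=torch.long),  # grid position of each patch
    'patch_valid': torch.ones(B, N, dtype=torch.bool),      # False marks padding patches
}
# output = model(naflex_input)                    # dict path handled above
# output = model(torch.randn(B, C_in, 224, 224))  # plain tensor path falls through to `patches = x`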
@@ -893,6 +874,8 @@ def forward(
             # patch = patch.reshape(3, h*16, w*16)
             # from torchvision.utils import save_image
             # save_image(patch, f'patch_{i}.jpg', normalize=True)
+        else:
+            patches = x
 
         # Create attention mask if patch_valid is provided
         if patch_valid is not None: