@@ -424,6 +424,53 @@ def create_attention_mask(
     return mask_float
 
 
+@register_notrace_function
+def global_pool_naflex(
+        x: torch.Tensor,
+        patch_valid: Optional[torch.Tensor] = None,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+):
+    if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
+        # Fall back to standard pooling
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
+        return x
+
+    # For NaFlex mode, apply masked pooling to exclude padding tokens.
+    # Only the patch tokens take part in pooling (prefix tokens are excluded).
+    if num_prefix_tokens > 0:
+        # The patch_valid mask only covers patch tokens, so drop the prefix tokens first
+        x = x[:, num_prefix_tokens:]  # prefix tokens not included in pooling
+
+    patch_valid_float = patch_valid.to(x.dtype)
+    if pool_type == 'avg':
+        # Masked average pooling: sum the valid tokens and divide by the count of valid tokens
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = masked_sums / valid_counts
+        return pooled
+    elif pool_type == 'avgmax':
+        # For avgmax, compute both the masked average and the masked max
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        masked_avg = masked_sums / valid_counts
+
+        # For the max component, set masked positions to a large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_max = masked_x.amax(dim=1)
+
+        # Combine average and max
+        return 0.5 * (masked_avg + masked_max)
+    elif pool_type == 'max':
+        # For max pooling, set masked positions to a large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        return masked_x.amax(dim=1)
+    else:
+        assert False
+
+
 class VisionTransformerFlex(nn.Module):
     """ Vision Transformer (Na)Flex
 
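For reference, a minimal standalone sketch of what the masked 'avg' branch of the new global_pool_naflex computes; the shapes and values below are made up for illustration, and the reference result simply averages each sample's valid tokens directly.

import torch

B, N, C = 2, 4, 3                       # batch, patch tokens, embed dim (illustrative sizes)
x = torch.randn(B, N, C)                # patch tokens, prefix tokens already sliced off
patch_valid = torch.tensor([[True, True, True, True],
                            [True, True, False, False]])  # second sample has two padding patches

patch_valid_float = patch_valid.to(x.dtype)
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts     # (B, C), padding excluded from the mean

# Matches averaging only the valid tokens of each sample
reference = torch.stack([x[i, patch_valid[i]].mean(dim=0) for i in range(B)])
assert torch.allclose(pooled, reference, atol=1e-6)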
@@ -817,38 +864,13 @@ def _pool(
            return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
-
-        # Handle padding mask for average pooling
-        if patch_valid is not None and pool_type in ('avg', 'avgmax'):
-            # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
-            # Extract only the patch part of the mask (excluding prefix tokens)
-            if self.num_prefix_tokens > 0:
-                # Apply the mask to extract only valid tokens
-                x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
-
-            patch_valid_float = patch_valid.to(x.dtype)
-            if pool_type == 'avg':
-                # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                pooled = masked_sums / valid_counts
-                return pooled
-            elif pool_type == 'avgmax':
-                # For avgmax, compute masked average and masked max
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                masked_avg = masked_sums / valid_counts
-
-                # For max pooling we set masked positions to large negative value
-                masked_x = x.clone()
-                masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
-                masked_max = masked_x.max(dim=1)[0]
-
-                # Combine average and max
-                return 0.5 * (masked_avg + masked_max)
 
-        # Fall back to standard pooling
-        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        x = global_pool_naflex(
+            x,
+            patch_valid,
+            pool_type=pool_type,
+            num_prefix_tokens=self.num_prefix_tokens,
+        )
         return x
 
     def forward_head(
@@ -897,14 +919,11 @@ def forward(
         patches = x
 
         # Create attention mask if patch_valid is provided
-        if patch_valid is not None:
-            attn_mask = create_attention_mask(
-                patch_valid,
-                num_prefix_tokens=self.num_prefix_tokens,
-                dtype=patches.dtype
-            )
-        else:
-            attn_mask = None
+        attn_mask = create_attention_mask(
+            patch_valid,
+            num_prefix_tokens=self.num_prefix_tokens,
+            dtype=patches.dtype,
+        )
 
         # Forward features with mask
         x = self.forward_features(
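The simplified call site above passes patch_valid through unconditionally, which presumes create_attention_mask itself short-circuits when no validity mask is given. A sketch of that kind of guard, under that assumption only (hypothetical helper name, not the actual timm implementation):

from typing import Optional

import torch


def create_attention_mask_sketch(
        patch_valid: Optional[torch.Tensor],
        num_prefix_tokens: int = 1,
        dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
    # Assumed behaviour: no patch_valid means no padding, so no additive mask is needed.
    if patch_valid is None:
        return None
    if num_prefix_tokens > 0:
        # Prefix (class / register) tokens are always treated as valid.
        prefix_valid = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    # Additive mask broadcastable to (B, heads, Q, K): zero where attention is allowed,
    # a large negative value at padded key positions.
    keep = patch_valid[:, None, None, :]
    mask_float = torch.zeros(keep.shape, dtype=dtype, device=patch_valid.device)
    return mask_float.masked_fill(~keep, torch.finfo(dtype).min)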