Commit 162f492

Move naflex global pool into one fn that can be marked notrace
1 parent 2ad75e8 commit 162f492
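
The motivation for the refactor is the decorator on the new function below: with the masked pooling isolated in a module-level global_pool_naflex, timm can register it as a non-traced leaf via @register_notrace_function, presumably so torch.fx-based feature extraction does not have to trace through the mask-dependent indexing. A minimal sketch of that leaf-function idea using only stock torch.fx.wrap; masked_max and Head here are illustrative names, not timm code:

# Minimal sketch (stock PyTorch only) of the leaf-function idea behind
# @register_notrace_function: wrap a module-level helper so torch.fx records a
# single call_function node instead of tracing through its internals.
# masked_max and Head are illustrative, not part of timm.
import torch
import torch.fx


def masked_max(x: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
    # Same pattern as the 'max' branch added in this commit: push padded
    # positions to the dtype minimum before reducing over the token dimension.
    x = x.clone()
    x[~valid] = torch.finfo(x.dtype).min
    return x.amax(dim=1)


torch.fx.wrap('masked_max')  # mark as a tracing leaf; must be called at module level


class Head(torch.nn.Module):
    def forward(self, x: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
        return masked_max(x, valid)


gm = torch.fx.symbolic_trace(Head())
print(gm.graph)  # masked_max appears as one node; its indexing is not traced
out = gm(torch.randn(2, 5, 4), torch.rand(2, 5) > 0.3)
print(out.shape)  # torch.Size([2, 4])

When Head is traced, masked_max shows up as a single call_function node, which is the same effect the decorator appears to aim for on global_pool_naflex.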


timm/models/vision_transformer_flex.py

Lines changed: 58 additions & 39 deletions
@@ -424,6 +424,53 @@ def create_attention_mask(
     return mask_float
 
 
+@register_notrace_function
+def global_pool_naflex(
+        x: torch.Tensor,
+        patch_valid: Optional[torch.Tensor] = None,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+):
+    if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
+        # Fall back to standard pooling
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
+        return x
+
+    # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
+    # Extract only the patch part of the mask (excluding prefix tokens)
+    if num_prefix_tokens > 0:
+        # Apply the mask to extract only valid tokens
+        x = x[:, num_prefix_tokens:]  # prefix tokens not included in pooling
+
+    patch_valid_float = patch_valid.to(x.dtype)
+    if pool_type == 'avg':
+        # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = masked_sums / valid_counts
+        return pooled
+    elif pool_type == 'avgmax':
+        # For avgmax, compute masked average and masked max
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        masked_avg = masked_sums / valid_counts
+
+        # For max pooling we set masked positions to large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_max = masked_x.amax(dim=1)
+
+        # Combine average and max
+        return 0.5 * (masked_avg + masked_max)
+    elif pool_type == 'max':
+        # For max pooling we set masked positions to large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        return masked_x.amax(dim=1)
+    else:
+        assert False
+
+
 class VisionTransformerFlex(nn.Module):
     """ Vision Transformer (Na)Flex
@@ -817,38 +864,13 @@ def _pool(
             return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
-
-        # Handle padding mask for average pooling
-        if patch_valid is not None and pool_type in ('avg', 'avgmax'):
-            # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
-            # Extract only the patch part of the mask (excluding prefix tokens)
-            if self.num_prefix_tokens > 0:
-                # Apply the mask to extract only valid tokens
-                x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
-
-            patch_valid_float = patch_valid.to(x.dtype)
-            if pool_type == 'avg':
-                # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                pooled = masked_sums / valid_counts
-                return pooled
-            elif pool_type == 'avgmax':
-                # For avgmax, compute masked average and masked max
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                masked_avg = masked_sums / valid_counts
-
-                # For max pooling we set masked positions to large negative value
-                masked_x = x.clone()
-                masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
-                masked_max = masked_x.max(dim=1)[0]
-
-                # Combine average and max
-                return 0.5 * (masked_avg + masked_max)
 
-        # Fall back to standard pooling
-        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        x = global_pool_naflex(
+            x,
+            patch_valid,
+            pool_type=pool_type,
+            num_prefix_tokens=self.num_prefix_tokens,
+        )
         return x
 
     def forward_head(
@@ -897,14 +919,11 @@ def forward(
         patches = x
 
         # Create attention mask if patch_type is provided
-        if patch_valid is not None:
-            attn_mask = create_attention_mask(
-                patch_valid,
-                num_prefix_tokens=self.num_prefix_tokens,
-                dtype=patches.dtype
-            )
-        else:
-            attn_mask = None
+        attn_mask = create_attention_mask(
+            patch_valid,
+            num_prefix_tokens=self.num_prefix_tokens,
+            dtype=patches.dtype,
+        )
 
         # Forward features with mask
         x = self.forward_features(
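
As for the pooling math itself: the 'avg' branch of global_pool_naflex above is a per-sample mean over valid (non-padded) patches, with padded positions zeroed out of the sum and excluded from the count. A standalone check with toy shapes (plain torch, no timm import; shapes and values are made up for illustration):

# Standalone check of the masked-average pooling used above (toy shapes only).
import torch

B, N, C = 2, 5, 4                        # batch, patches (incl. padding), channels
x = torch.randn(B, N, C)
patch_valid = torch.tensor([             # True where a patch is real, False where padded
    [True, True, True, False, False],
    [True, True, True, True, True],
])

# Vectorized masked average, mirroring the 'avg' branch of global_pool_naflex
valid_f = patch_valid.to(x.dtype)
masked_sums = (x * valid_f.unsqueeze(-1)).sum(dim=1)
valid_counts = valid_f.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts

# Reference: average only the valid rows, one sample at a time
reference = torch.stack([x[i, patch_valid[i]].mean(dim=0) for i in range(B)])
assert torch.allclose(pooled, reference, atol=1e-6)
print(pooled.shape)  # torch.Size([2, 4])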
