Commit b4bb0f4

Exclude embeds module and mask attn functions from tracing

1 parent: 13e0f3a
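
For context: register_notrace_module and register_notrace_function record the decorated class or function so that timm's FX-based feature extraction treats the module as a leaf and the function as a single wrapped call, instead of tracing into their data-dependent internals. Below is a minimal sketch of that registration pattern, assuming plain set-based registries; it is an illustration, not the actual contents of timm/models/_features_fx.py.

from typing import Callable, Set, Type

import torch.nn as nn

# Hypothetical registries standing in for whatever _features_fx keeps internally;
# the FX tracer wiring would consult these when a feature extractor is built.
_leaf_modules: Set[Type[nn.Module]] = set()   # module classes the tracer must not step into
_autowrap_functions: Set[Callable] = set()    # functions recorded as single call nodes


def register_notrace_module(module_cls: Type[nn.Module]) -> Type[nn.Module]:
    # Used as a class decorator (e.g. on FlexEmbeds below); returns the class unchanged.
    _leaf_modules.add(module_cls)
    return module_cls


def register_notrace_function(func: Callable) -> Callable:
    # Used as a function decorator (e.g. on the mask helpers below); returns it unchanged.
    _autowrap_functions.add(func)
    return func

Because both decorators return their argument unchanged, stacking them above a definition adds no overhead to a normal forward pass; they only matter when an FX graph is traced.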

1 file changed (12 additions, 6 deletions)

timm/models/vision_transformer_flex.py
@@ -27,6 +27,7 @@
 from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
+from timm.models._features_fx import register_notrace_function, register_notrace_module
 from timm.models._registry import register_model, generate_default_cfgs
 from timm.models._manipulate import checkpoint_seq, named_apply
 
@@ -55,6 +56,7 @@ def batch_patchify(
     return patches, (nh, nw)
 
 
+@register_notrace_module
 class FlexEmbeds(nn.Module):
     """ Na(Flex) Embedding module for Vision Transformers
@@ -216,18 +218,18 @@ def forward(self, x, patch_coord=None, patch_valid=None):
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
         grid_size: Optional[Tuple[int, int]] = None
 
+        B = x.shape[0]
         if self.is_linear:
             # Linear embedding path, works with NaFlex mode or standard 2D mode
-            B = x.shape[0]
-            if x.ndim == 3:
-                # pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
-                _assert(patch_coord is not None, 'patch_coord must not be None in NaFlex mode')
-
+            if patch_coord is not None:
+                _assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
+                # Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
                 naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
             else:
+                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
 
         if self.norm_input is not None:
@@ -252,7 +254,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
             x = self.norm_proj(x)
 
         if self.pos_embed_type == 'learned':
-            if naflex_grid_sizes:
+            if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
@@ -336,6 +338,7 @@ def _apply_learned_pos_embed(
         x.add_(pos_embed)
 
 
+@register_notrace_function
 def create_attention_mask(
         patch_valid: Optional[torch.Tensor],
         num_prefix_tokens: int = 0,
@@ -367,6 +370,8 @@ def create_attention_mask(
 
     return mask_float
 
+
+@register_notrace_function
 def create_attention_mask2(
         patch_valid: Optional[torch.Tensor],
         num_prefix_tokens: int = 0,
@@ -404,6 +409,7 @@ def create_attention_mask2(
     return mask_float
 
 
+@register_notrace_function
 def create_pool_mask(
         patch_valid: Optional[torch.Tensor],
         dtype: torch.dtype = torch.float32,
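
Why the exclusion is needed: the grid-size computation in FlexEmbeds.forward iterates tensors and calls .item() to obtain concrete per-sample values, which torch.fx symbolic tracing cannot follow; the mask helpers are registered for similar data-dependent reasons. The following toy sketch shows the failure mode and the leaf-module workaround using plain torch.fx rather than timm's helpers; ToyGridEmbed, Wrapper and LeafTracer are made-up names for illustration only.

import torch
import torch.fx
import torch.nn as nn


class ToyGridEmbed(nn.Module):
    # Mimics the data-dependent part of the embedding path: per-sample grid sizes
    # are derived as concrete Python ints, which symbolic proxies cannot provide.
    def forward(self, patch_coord: torch.Tensor):
        max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
        max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
        return [(h.item(), w.item()) for h, w in zip(max_y, max_x)]


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = ToyGridEmbed()

    def forward(self, patch_coord):
        return self.embed(patch_coord)


class LeafTracer(torch.fx.Tracer):
    # Treat ToyGridEmbed as a leaf so it is recorded as a single call_module node;
    # register_notrace_module plays the analogous role for FlexEmbeds inside timm's
    # feature-extraction plumbing.
    def is_leaf_module(self, m: nn.Module, qualname: str) -> bool:
        return isinstance(m, ToyGridEmbed) or super().is_leaf_module(m, qualname)


try:
    torch.fx.symbolic_trace(Wrapper())   # steps into ToyGridEmbed, fails at zip()/.item()
except Exception as e:
    print('default tracing failed:', type(e).__name__)

root = Wrapper()
graph = LeafTracer().trace(root)         # succeeds: embed appears as one opaque node
gm = torch.fx.GraphModule(root, graph)
print(gm.graph)

The same reasoning applies to create_attention_mask, create_attention_mask2 and create_pool_mask: registering them as notrace functions lets graph-based feature extraction record each call as a single node rather than tracing through the mask construction.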
