 from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
+from timm.models._features_fx import register_notrace_function, register_notrace_module
 from timm.models._registry import register_model, generate_default_cfgs
 from timm.models._manipulate import checkpoint_seq, named_apply
 
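Note on the new import: as I understand timm's FX feature-extraction support, register_notrace_module marks a module as a leaf and register_notrace_function auto-wraps a function so torch.fx symbolic tracing does not descend into their bodies. A minimal sketch of why that matters here; grid_height_from_coords is a hypothetical stand-in, not code from this file:

import torch
from timm.models._features_fx import register_notrace_function

@register_notrace_function
def grid_height_from_coords(patch_coord: torch.Tensor) -> int:
    # Tensor.item() demands a concrete value, so torch.fx symbolic tracing
    # would fail inside this body; registration keeps the call un-traced.
    return int(patch_coord[:, :, 0].max().item()) + 1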
@@ -55,6 +56,7 @@ def batch_patchify(
     return patches, (nh, nw)
 
 
+@register_notrace_module
 class FlexEmbeds(nn.Module):
     """ Na(Flex) Embedding module for Vision Transformers
 
@@ -216,18 +218,18 @@ def forward(self, x, patch_coord=None, patch_valid=None):
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
         grid_size: Optional[Tuple[int, int]] = None
 
+        B = x.shape[0]
         if self.is_linear:
             # Linear embedding path, works with NaFlex mode or standard 2D mode
-            B = x.shape[0]
-            if x.ndim == 3:
-                # pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
-                _assert(patch_coord is not None, 'patch_coord must not be None in NaFlex mode')
-
+            if patch_coord is not None:
+                _assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
+                # Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
                 naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
             else:
+                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
 
         if self.norm_input is not None:
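A toy run of the grid-size recovery above (a sketch, not part of the diff). patch_coord is assumed to be (B, N, 2) holding (y, x) patch indices, so per-sample max + 1 yields each image's patch grid; the .item() calls are the data-dependent step that motivates the notrace registration:

import torch

patch_coord = torch.tensor([
    [[0, 0], [0, 1], [1, 0], [1, 1]],  # sample 0: 2x2 patch grid
    [[0, 0], [0, 1], [0, 2], [0, 0]],  # sample 1: 1x3 grid, last entry padding
])
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
print([(h.item(), w.item()) for h, w in zip(max_y, max_x)])  # [(2, 2), (1, 3)]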
@@ -252,7 +254,7 @@ def forward(self, x, patch_coord=None, patch_valid=None):
             x = self.norm_proj(x)
 
         if self.pos_embed_type == 'learned':
-            if naflex_grid_sizes:
+            if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
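The switch to an explicit None check is more than style: an empty list is falsy in Python, so the old truthiness test would route a NaFlex batch with zero samples to the fixed-grid branch. A two-line demonstration:

naflex_grid_sizes = []                 # NaFlex path, degenerate empty batch
print(bool(naflex_grid_sizes))         # False -> old check takes the wrong branch
print(naflex_grid_sizes is not None)   # True  -> new check dispatches correctly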
@@ -336,6 +338,7 @@ def _apply_learned_pos_embed(
         x.add_(pos_embed)
 
 
+@register_notrace_function
 def create_attention_mask(
         patch_valid: Optional[torch.Tensor],
         num_prefix_tokens: int = 0,
@@ -367,6 +370,8 @@ def create_attention_mask(
 
     return mask_float
 
+
+@register_notrace_function
 def create_attention_mask2(
         patch_valid: Optional[torch.Tensor],
         num_prefix_tokens: int = 0,
@@ -404,6 +409,7 @@ def create_attention_mask2(
     return mask_float
 
 
+@register_notrace_function
 def create_pool_mask(
         patch_valid: Optional[torch.Tensor],
        dtype: torch.dtype = torch.float32,
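For context, the kind of mask these helpers produce; their bodies fall outside the hunks shown, so make_additive_attn_mask below is a hypothetical reconstruction from the signatures, not the library's implementation. From a (B, N) boolean patch_valid it builds an additive float mask, 0 where attention is allowed and -inf at padding keys, broadcastable over heads:

import torch

def make_additive_attn_mask(
        patch_valid: torch.Tensor,      # (B, N) bool, True = real patch
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    B = patch_valid.shape[0]
    if num_prefix_tokens:
        # Prefix tokens (class/register) are always valid.
        prefix = patch_valid.new_ones(B, num_prefix_tokens)
        patch_valid = torch.cat([prefix, patch_valid], dim=1)
    N = patch_valid.shape[1]
    mask = torch.zeros(B, 1, N, N, dtype=dtype)
    # Block attention *to* padding keys; padding-query rows are discarded downstream.
    return mask.masked_fill(~patch_valid[:, None, None, :], float('-inf'))

valid = torch.tensor([[True, True, False]])
print(make_additive_attn_mask(valid, num_prefix_tokens=1)[0, 0])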