@@ -280,41 +280,51 @@ def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
 
         return x
 
+    #@torch.compiler.disable()
     def _apply_learned_naflex_pos_embed(
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
     ):
-        orig_h, orig_w = self.pos_embed.shape[1:3]
-
-        # Determine unique grid sizes
-        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
-        for bi, (h, w) in enumerate(naflex_grid_sizes):
-            #k = h << 16 | w  # FIXME can get jit compat with this
-            k = (h, w)
-            if not k in size_to_indices:
-                size_to_indices[k] = [bi]
-            else:
-                size_to_indices[k].append(bi)
-
         # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
-        for k, batch_indices in size_to_indices.items():
-            h, w = k
-            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
-            # Interpolate only once for this (h, w)
-            if (h == orig_h) and (w == orig_w):
+
+        def _interp(_size):
+            if (_size[0] == orig_h) and (_size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=(h, w),
+                    size=_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
                 ).flatten(2).transpose(1, 2)
-            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
+            return pos_embed_flat.to(dtype=x.dtype)
+
+        # FIXME leaving alternative code commented here for now for comparisons
+        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+        # for i, s in enumerate(naflex_grid_sizes):
+        #     if s in pos_embed_cache:
+        #         pos_embed_flat = pos_embed_cache[s]
+        #     else:
+        #         pos_embed_flat = _interp(s)
+        #         pos_embed_cache[s] = pos_embed_flat
+        #
+        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
 
+        # Determine unique grid sizes
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            # k = h << 16 | w  # FIXME can get jit compat with this
+            size_to_indices.setdefault(k, []).append(bi)
+
+        for k, batch_indices in size_to_indices.items():
+            # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
+            # Interpolate only once for this (h, w)
+            pos_embed_flat = _interp(k)
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
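The refactored hunk above keeps the "interpolate once per unique grid size" idea, but routes the resize through a small `_interp` helper and uses `setdefault` for the grouping. Below is a minimal standalone sketch of the resulting pattern; the toy shapes, the `bilinear` mode, and the zero-initialized `x` are assumptions for illustration, not the module's actual configuration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

# Toy inputs (assumed shapes for illustration only).
B, N, C = 4, 64, 32                      # batch, max seq len, embed dim
orig_h, orig_w = 7, 7                    # stored pos embed grid
pos_embed = torch.randn(1, orig_h, orig_w, C)
x = torch.zeros(B, N, C)                 # patch tokens (zeros so the result is just the pos embed)
naflex_grid_sizes = [(7, 7), (8, 6), (8, 6), (4, 4)]  # per-sample (h, w) grids

# 1. Group batch indices by unique grid size so each size is interpolated once
#    (equivalent to the setdefault loop in the patch).
size_to_indices: Dict[Tuple[int, int], List[int]] = defaultdict(list)
for bi, k in enumerate(naflex_grid_sizes):
    size_to_indices[k].append(bi)

pos_embed_nchw = pos_embed.permute(0, 3, 1, 2).float()  # 1,C,H,W for F.interpolate

for (h, w), batch_indices in size_to_indices.items():
    # 2. Resample the stored grid to (h, w), skipping the no-op case.
    if (h, w) == (orig_h, orig_w):
        pe = pos_embed.reshape(1, orig_h * orig_w, -1)
    else:
        pe = F.interpolate(
            pos_embed_nchw, size=(h, w), mode='bilinear',
            align_corners=False, antialias=True,
        ).flatten(2).transpose(1, 2)
    pe = pe.to(dtype=x.dtype)

    # 3. Add the same embedding to every batch row that uses this grid size.
    seq_len = min(N, pe.shape[1])
    x[:, :seq_len].index_add_(
        0,
        torch.as_tensor(batch_indices),
        pe[:, :seq_len].expand(len(batch_indices), -1, -1),
    )
```

The commented-out per-sample cache variant left in the patch trades the batched `index_add_` for a simple per-row loop; per the FIXME note it is kept only for comparison.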
@@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 
 default_cfgs = generate_default_cfgs({
-    'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
 
@@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def vit_naflex_base_patch16(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
-    This model supports:
-    1. Variable aspect ratios and resolutions via patch coordinates
-    2. Position embedding interpolation for arbitrary grid sizes
-    3. Explicit patch coordinates and valid token masking
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
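The registration functions in this and the following hunks stop folding `**kwargs` into `model_args` and instead merge at the call site with `**dict(model_args, **kwargs)`. The practical difference is that a user override of a variant default (e.g. `global_pool`) no longer collides with the hard-coded keyword. A small sketch of the two patterns, with hypothetical values:

```python
# Hypothetical defaults for one variant, used only to illustrate the merge pattern.
defaults = dict(patch_size=16, embed_dim=768, global_pool='avg')
user_kwargs = dict(global_pool='token', drop_rate=0.1)

# Old pattern: duplicate keys collide when the user overrides a default.
try:
    merged = dict(patch_size=16, embed_dim=768, global_pool='avg', **user_kwargs)
except TypeError as err:
    print(f"old pattern fails: {err}")  # duplicate keyword argument 'global_pool'

# New pattern: kwargs are merged last, so user values win without error.
merged = dict(defaults, **user_kwargs)
print(merged['global_pool'])  # 'token'
print(merged['drop_rate'])    # 0.1
```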
@@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='map', reg_tokens=1)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
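For completeness, a minimal usage sketch of one surviving variant, assuming these registrations are present in the installed timm build and that the dense-image path (leaving `patch_coord` at its `None` default, as in the forward signature at the top of the diff) behaves like a standard ViT forward; the input resolution and `num_classes` override are illustrative only.

```python
import timm
import torch

# Assumes the NaFlex variants registered above are available in this timm build.
model = timm.create_model('vit_naflex_base_patch16_gap', pretrained=False, num_classes=10)
model.eval()

# Dense-image path: patch_coord defaults to None, so this is an ordinary image forward.
x = torch.randn(1, 3, 256, 256)  # assumed input resolution for illustration
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 10])
```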