Commit 9885d69

Support set_input_size() in EVA models

1 parent 19f2bfb commit 9885d69

File tree: 2 files changed, +101 -29 lines changed

timm/layers/pos_embed_sincos.py

Lines changed: 72 additions & 29 deletions
@@ -354,6 +354,7 @@ def __init__(
         self.dim = dim
         self.max_res = max_res
         self.temperature = temperature
+        self.linear_bands = linear_bands
         self.in_pixels = in_pixels
         self.feat_shape = feat_shape
         self.ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
             self.pos_embed_cos = None
         else:
             # cache full sin/cos embeddings if shape provided up front
-            emb_sin, emb_cos = build_rotary_pos_embed(
-                feat_shape=feat_shape,
-                dim=dim,
-                max_res=max_res,
-                linear_bands=linear_bands,
-                in_pixels=in_pixels,
-                ref_feat_shape=self.ref_feat_shape,
-                grid_offset=self.grid_offset,
-                grid_indexing=self.grid_indexing,
-                temperature=self.temperature,
-            )
+            emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
             self.bands = None
             self.register_buffer(
                 'pos_embed_sin',
@@ -406,6 +397,29 @@ def __init__(
                 persistent=False,
             )
 
+    def _get_pos_embed_values(self, feat_shape: List[int]):
+        emb_sin, emb_cos = build_rotary_pos_embed(
+            feat_shape=feat_shape,
+            dim=self.dim,
+            max_res=self.max_res,
+            temperature=self.temperature,
+            linear_bands=self.linear_bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+        return emb_sin, emb_cos
+
+    def update_feat_shape(self, feat_shape: List[int]):
+        if self.feat_shape is not None and feat_shape != self.feat_shape:
+            # only update if feat_shape was set and differs from previous value
+            assert self.pos_embed_sin is not None
+            emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
+            self.pos_embed_sin = emb_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype)
+            self.pos_embed_cos = emb_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype)
+            self.feat_shape = feat_shape
+
     def get_embed(self, shape: Optional[List[int]] = None):
         if shape is not None and self.bands is not None:
             # rebuild embeddings every call, use if target shape changes
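The refactor above moves the cached-path construction into _get_pos_embed_values() so that update_feat_shape() can re-run it when the feature grid changes. A minimal sketch of the resulting behaviour, assuming default constructor arguments (the dim and grid sizes are illustrative, not from the commit):

    from timm.layers.pos_embed_sincos import RotaryEmbedding

    # fixing feat_shape up front caches the sin/cos tables as buffers
    rope = RotaryEmbedding(dim=64, feat_shape=[7, 7])
    sin, cos = rope.get_embed()  # returns the cached pos_embed_sin / pos_embed_cos

    # rebuild both cached buffers for a new grid, preserving device and dtype
    rope.update_feat_shape([14, 14])
    sin, cos = rope.get_embed()  # now sized for the 14x14 grid

When feat_shape is None the module caches self.bands instead and rebuilds embeddings on every get_embed(shape) call, so update_feat_shape() only applies to the cached path; hence the assert on the cached buffer.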
@@ -453,6 +467,7 @@ def __init__(
         self.max_res = max_res
         self.temperature = temperature
         self.in_pixels = in_pixels
+        self.linear_bands = linear_bands
         self.feat_shape = feat_shape
         self.ref_feat_shape = ref_feat_shape
         self.grid_offset = grid_offset
@@ -480,27 +495,40 @@ def __init__(
             self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
-            embeds = build_rotary_pos_embed(
-                feat_shape=feat_shape,
-                dim=dim,
-                max_res=max_res,
-                linear_bands=linear_bands,
-                in_pixels=in_pixels,
-                ref_feat_shape=self.ref_feat_shape,
-                grid_offset=self.grid_offset,
-                grid_indexing=self.grid_indexing,
-                temperature=self.temperature,
-            )
             self.bands = None
             self.register_buffer(
                 'pos_embed',
-                torch.cat(embeds, -1),
+                self._get_pos_embed_values(feat_shape=feat_shape),
                 persistent=False,
             )
 
+    def _get_pos_embed_values(self, feat_shape: List[int]):
+        embeds = build_rotary_pos_embed(
+            feat_shape=feat_shape,
+            dim=self.dim,
+            max_res=self.max_res,
+            temperature=self.temperature,
+            linear_bands=self.linear_bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+        return torch.cat(embeds, -1)
+
+    def update_feat_shape(self, feat_shape: List[int]):
+        if self.feat_shape is not None and feat_shape != self.feat_shape:
+            # only update if feat_shape was set and differs from previous value
+            assert self.pos_embed is not None
+            self.pos_embed = self._get_pos_embed_values(feat_shape).to(
+                device=self.pos_embed.device,
+                dtype=self.pos_embed.dtype,
+            )
+            self.feat_shape = feat_shape
+
     def get_embed(self, shape: Optional[List[int]] = None):
         if shape is not None and self.bands is not None:
-            # rebuild embeddings every call, use if target shape changes
+            # rebuild embeddings from cached bands every call, use if target shape changes
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
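The concatenated variant gets the same split, with its single pos_embed buffer rebuilt in place. A sketch under the same caveats (the arguments mirror how the EVA models appear to construct their rope, with in_pixels=False and a reference feature shape, but are not taken from the commit):

    from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

    # head_dim 64, 16x16 patch grid
    rope = RotaryEmbeddingCat(dim=64, in_pixels=False, feat_shape=[16, 16], ref_feat_shape=[16, 16])
    print(rope.pos_embed.shape)  # sin and cos halves concatenated on the last dim

    rope.update_feat_shape([24, 24])  # rebuild for a larger grid, device/dtype preserved
    print(rope.pos_embed.shape)  # leading dim is now 24 * 24 = 576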
@@ -684,6 +712,7 @@ def __init__(
 
         head_dim = dim // num_heads
         assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
+
         freqs = init_random_2d_freqs(
             head_dim,
             depth,
@@ -692,18 +721,32 @@ def __init__(
             rotate=True,
         )  # (2, depth, num_heads, head_dim//2)
         self.freqs = nn.Parameter(freqs)
+
         if feat_shape is not None:
             # cache pre-computed grid
-            t_x, t_y = get_mixed_grid(
-                feat_shape,
-                grid_indexing=grid_indexing,
-                device=self.freqs.device
-            )
+            t_x, t_y = self._get_grid_values(feat_shape)
             self.register_buffer('t_x', t_x, persistent=False)
             self.register_buffer('t_y', t_y, persistent=False)
         else:
             self.t_x = self.t_y = None
 
+    def _get_grid_values(self, feat_shape: Optional[List[int]]):
+        t_x, t_y = get_mixed_grid(
+            feat_shape,
+            grid_indexing=self.grid_indexing,
+            device=self.freqs.device
+        )
+        return t_x, t_y
+
+    def update_feat_shape(self, feat_shape: Optional[List[int]]):
+        if self.feat_shape is not None and feat_shape != self.feat_shape:
+            assert self.t_x is not None
+            assert self.t_y is not None
+            t_x, t_y = self._get_grid_values(feat_shape)
+            self.t_x = t_x.to(self.t_x.device, self.t_x.dtype)
+            self.t_y = t_y.to(self.t_y.device, self.t_y.dtype)
+            self.feat_shape = feat_shape
+
     def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
         """Generate rotary embeddings for the given spatial shape.

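In the mixed variant the frequencies (self.freqs) are a learned nn.Parameter, so a resolution change must not touch them; only the cached coordinate grids t_x/t_y are regenerated. A sketch, with constructor arguments that are assumptions rather than taken from the commit:

    from timm.layers.pos_embed_sincos import RotaryEmbeddingMixed

    # head_dim = 768 // 12 = 64, divisible by 4 as the assert above requires
    rope = RotaryEmbeddingMixed(dim=768, depth=12, num_heads=12, feat_shape=[14, 14])

    # regenerates only the t_x / t_y grid buffers; the learned freqs are untouched,
    # so this is safe to call on a trained model
    rope.update_feat_shape([16, 16])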
timm/models/eva.py

Lines changed: 29 additions & 0 deletions
@@ -723,6 +723,35 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None)
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def set_input_size(
+            self,
+            img_size: Optional[Tuple[int, int]] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """Update the input image resolution and patch size.
+
+        Args:
+            img_size: New input resolution; if None, the current resolution is used.
+            patch_size: New patch size; if None, the existing patch size is used.
+        """
+        prev_grid_size = self.patch_embed.grid_size
+        self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
+
+        if self.pos_embed is not None:
+            num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
+            num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
+            if num_new_tokens != self.pos_embed.shape[1]:
+                self.pos_embed = nn.Parameter(resample_abs_pos_embed(
+                    self.pos_embed,
+                    new_size=self.patch_embed.grid_size,
+                    old_size=prev_grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
+                ))
+
+        if self.rope is not None:
+            self.rope.update_feat_shape(self.patch_embed.grid_size)
+
     def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if self.dynamic_img_size:
             B, H, W, C = x.shape
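Putting the two files together: set_input_size() retargets the patch embedding, resamples the absolute position embedding only when the token count actually changes, and refreshes the rope cache via the new update_feat_shape() hooks. A usage sketch (model name and resolutions are illustrative, not from the commit):

    import torch
    import timm

    # build at the pretraining resolution
    model = timm.create_model('eva02_small_patch14_224.mim_in22k', pretrained=False)

    # retarget to 336x336: patch grid, abs pos embed, and rope cache all update
    model.set_input_size(img_size=(336, 336))

    with torch.no_grad():
        out = model(torch.randn(1, 3, 336, 336))
    print(out.shape)  # logits at the new input resolution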
