Skip to content

Commit 598f322

Browse files
committed
Resolve torchscript issue
1 parent 6f452af commit 598f322

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

timm/models/eva.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ def __init__(
550550
if use_rot_pos_emb:
551551
ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
552552
if rope_mixed_mode:
553+
self.rope_mixed = True
553554
# Mixed mode to support depth-dependent frequencies
554555
self.rope = RotaryEmbeddingMixed(
555556
dim=embed_dim,
@@ -560,6 +561,7 @@ def __init__(
560561
grid_indexing=rope_grid_indexing,
561562
)
562563
else:
564+
self.rope_mixed = False
563565
self.rope = RotaryEmbeddingCat(
564566
dim=embed_dim // num_heads,
565567
temperature=rope_temperature,
@@ -570,6 +572,7 @@ def __init__(
570572
grid_indexing=rope_grid_indexing,
571573
)
572574
else:
575+
self.rope_mixed = False
573576
self.rope = None
574577

575578
self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()
@@ -770,7 +773,7 @@ def forward_intermediates(
770773
else:
771774
blocks = self.blocks[:max_index + 1]
772775
# Handle depth-dependent embeddings for mixed mode
773-
if rot_pos_embed is not None and isinstance(self.rope, RotaryEmbeddingMixed):
776+
if self.rope_mixed and rot_pos_embed is not None:
774777
for i, blk in enumerate(blocks):
775778
if self.grad_checkpointing and not torch.jit.is_scripting():
776779
x = checkpoint(blk, x, rope=rot_pos_embed[i])
@@ -847,7 +850,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
847850
x = self.norm_pre(x)
848851

849852
# Handle depth-dependent embeddings for mixed mode
850-
if rot_pos_embed is not None and isinstance(self.rope, RotaryEmbeddingMixed):
853+
if self.rope_mixed and rot_pos_embed is not None:
851854
# rot_pos_embed has shape (depth, H*W, dim) for mixed mode
852855
for i, blk in enumerate(self.blocks):
853856
if self.grad_checkpointing and not torch.jit.is_scripting():

0 commit comments

Comments
 (0)