@@ -550,6 +550,7 @@ def __init__(
550550 if use_rot_pos_emb :
551551 ref_feat_shape = to_2tuple (ref_feat_shape ) if ref_feat_shape is not None else None
552552 if rope_mixed_mode :
553+ self .rope_mixed = True
553554 # Mixed mode to support depth-dependent frequencies
554555 self .rope = RotaryEmbeddingMixed (
555556 dim = embed_dim ,
@@ -560,6 +561,7 @@ def __init__(
560561 grid_indexing = rope_grid_indexing ,
561562 )
562563 else :
564+ self .rope_mixed = False
563565 self .rope = RotaryEmbeddingCat (
564566 dim = embed_dim // num_heads ,
565567 temperature = rope_temperature ,
@@ -570,6 +572,7 @@ def __init__(
570572 grid_indexing = rope_grid_indexing ,
571573 )
572574 else :
575+ self .rope_mixed = False
573576 self .rope = None
574577
575578 self .norm_pre = norm_layer (embed_dim ) if activate_pre_norm else nn .Identity ()
@@ -770,7 +773,7 @@ def forward_intermediates(
770773 else :
771774 blocks = self .blocks [:max_index + 1 ]
772775 # Handle depth-dependent embeddings for mixed mode
773- if rot_pos_embed is not None and isinstance ( self . rope , RotaryEmbeddingMixed ) :
776+ if self . rope_mixed and rot_pos_embed is not None :
774777 for i , blk in enumerate (blocks ):
775778 if self .grad_checkpointing and not torch .jit .is_scripting ():
776779 x = checkpoint (blk , x , rope = rot_pos_embed [i ])
@@ -847,7 +850,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
847850 x = self .norm_pre (x )
848851
849852 # Handle depth-dependent embeddings for mixed mode
850- if rot_pos_embed is not None and isinstance ( self . rope , RotaryEmbeddingMixed ) :
853+ if self . rope_mixed and rot_pos_embed is not None :
851854 # rot_pos_embed has shape (depth, H*W, dim) for mixed mode
852855 for i , blk in enumerate (self .blocks ):
853856 if self .grad_checkpointing and not torch .jit .is_scripting ():
0 commit comments