@@ -445,16 +445,16 @@ def __init__(self,
445445 self .num_windows = int ((input_resolution // window_size ) * (input_resolution // window_size ))
446446
447447 def forward (self , x , q_global ):
448- B , H , W , C = x .shape
449- shortcut = x
450- x = self .norm1 (x )
451- x_windows = window_partition (x , self .window_size )
452- x_windows = x_windows .view (- 1 , self .window_size * self .window_size , C )
453- attn_windows = self .attn (x_windows , q_global )
454- x = window_reverse (attn_windows , self .window_size , H , W )
455- x = shortcut + self .drop_path (self .gamma1 * x )
456- x = x + self .drop_path (self .gamma2 * self .mlp (self .norm2 (x )))
457- return x
448+ B , H , W , C = x .shape
449+ shortcut = x
450+ x = self .norm1 (x )
451+ x_windows = window_partition (x , self .window_size )
452+ x_windows = x_windows .view (- 1 , self .window_size * self .window_size , C )
453+ attn_windows = self .attn (x_windows , q_global )
454+ x = window_reverse (attn_windows , self .window_size , H , W )
455+ x = shortcut + self .drop_path (self .gamma1 * x )
456+ x = x + self .drop_path (self .gamma2 * self .mlp (self .norm2 (x )))
457+ return x
458458
459459
460460class GlobalQueryGen (nn .Module ):
@@ -474,6 +474,9 @@ def __init__(self,
474474 input_resolution: input image resolution.
475475 window_size: window size.
476476 num_heads: number of heads.
477+
478+ For instance, repeating log(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at
479+ down-sampling ratio 2. Please check Fig.5 of GC ViT paper for details.
477480 """
478481
479482 super ().__init__ ()
0 commit comments