@@ -389,7 +389,29 @@ def __init__(
             nn.LayerNorm(dim) if dual_patchnorm else None,
         )
 
-        self.axial_pos_emb = nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)
+        # axial positional embeddings, parameterized by an MLP
+
+        pos_emb_dim = dim // 2
+
+        self.axial_pos_emb_height_mlp = nn.Sequential(
+            Rearrange('... -> ... 1'),
+            nn.Linear(1, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, dim)
+        )
+
+        self.axial_pos_emb_width_mlp = nn.Sequential(
+            Rearrange('... -> ... 1'),
+            nn.Linear(1, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, dim)
+        )
+
+        # nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)
 
         self.to_pixels = nn.Sequential(
             LayerNorm(dim),
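A minimal standalone sketch (not part of the diff) of one of the axial positional embedding MLPs added above, assuming dim = 512; the names pos_mlp and coords are hypothetical and used only for illustration. It shows that a 1D tensor of normalized patch coordinates maps to one dim-sized embedding per position along that axis.

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512
pos_emb_dim = dim // 2

pos_mlp = nn.Sequential(
    Rearrange('... -> ... 1'),            # (n,) coordinates -> (n, 1)
    nn.Linear(1, pos_emb_dim),
    nn.SiLU(),
    nn.Linear(pos_emb_dim, pos_emb_dim),
    nn.SiLU(),
    nn.Linear(pos_emb_dim, dim)
)

coords = torch.linspace(0., 1., steps = 16)  # one normalized coordinate per patch along one axis
print(pos_mlp(coords).shape)                 # torch.Size([16, 512])

Because the embedding is a continuous function of the coordinate rather than a fixed-size learned table, the same module can be queried at a different number of positions (a different patch grid) without resizing any parameters.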
@@ -414,6 +436,10 @@ def __init__(
 
         self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     def forward(
         self,
         x,
@@ -451,7 +477,9 @@ def forward(
 
         patches = self.to_patches(x)
 
-        pos_emb_h, pos_emb_w = self.axial_pos_emb
+        height_range = width_range = torch.linspace(0., 1., steps = int(math.sqrt(patches.shape[-2])), device = self.device)
+        pos_emb_h, pos_emb_w = self.axial_pos_emb_height_mlp(height_range), self.axial_pos_emb_width_mlp(width_range)
+
         pos_emb = rearrange(pos_emb_h, 'i d -> i 1 d') + rearrange(pos_emb_w, 'j d -> 1 j d')
         patches = patches + rearrange(pos_emb, 'i j d -> (i j) d')
 
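A sketch (not part of the diff) of how the forward pass above combines the two axial embeddings, assuming a 16 x 16 patch grid and dim = 512; pos_emb_h and pos_emb_w are random stand-ins for the outputs of the two MLPs.

import math
import torch
from einops import rearrange

dim = 512
patches = torch.randn(2, 16 * 16, dim)       # (batch, number of patches, dim)

side = int(math.sqrt(patches.shape[-2]))     # 16; assumes the patch grid is square
pos_emb_h = torch.randn(side, dim)           # stand-in for axial_pos_emb_height_mlp(height_range)
pos_emb_w = torch.randn(side, dim)           # stand-in for axial_pos_emb_width_mlp(width_range)

# outer sum over the two axes gives one embedding per (row, column) position
pos_emb = rearrange(pos_emb_h, 'i d -> i 1 d') + rearrange(pos_emb_w, 'j d -> 1 j d')
patches = patches + rearrange(pos_emb, 'i j d -> (i j) d')

print(patches.shape)                         # torch.Size([2, 256, 512])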