From bf5dc211713a11f9033810146ee0833c9957635b Mon Sep 17 00:00:00 2001
From: Yonghye Kwon
Date: Fri, 3 Sep 2021 17:57:45 +0900
Subject: [PATCH] skip .view() func in WindowAttention

---
 models/swin_transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/models/swin_transformer.py b/models/swin_transformer.py
index cfeb0f22..f1789290 100644
--- a/models/swin_transformer.py
+++ b/models/swin_transformer.py
@@ -99,7 +99,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at
         relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index = relative_coords.sum(-1).flatten()  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@@ -123,7 +123,7 @@ def forward(self, x, mask=None):
         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
 
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view(
             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
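
Note on the change: the flatten now happens once in __init__, so the registered
buffer is stored 1-D and the per-forward .view(-1) on it can be skipped. Below is
a minimal standalone sketch, not the full WindowAttention module, that rebuilds
the index both ways and checks the two lookups produce the same bias; the 7x7
window and 3 heads are illustrative assumptions, not values taken from this patch.

import torch

# Toy sizes for illustration only (Swin-T happens to use a 7x7 window).
Wh, Ww = 7, 7
num_heads = 3

# Build the relative position index the same way WindowAttention.__init__ does.
coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww)))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += Wh - 1  # shift to start from 0
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1

index_2d = relative_coords.sum(-1)            # old buffer: Wh*Ww, Wh*Ww
index_1d = relative_coords.sum(-1).flatten()  # new buffer: (Wh*Ww)*(Wh*Ww)

# Stand-in for self.relative_position_bias_table.
table = torch.randn((2 * Wh - 1) * (2 * Ww - 1), num_heads)

# Old lookup: flatten the 2-D index on every forward call, then reshape.
bias_old = table[index_2d.view(-1)].view(Wh * Ww, Wh * Ww, -1)
# New lookup: the buffer is already 1-D, so the .view(-1) is gone.
bias_new = table[index_1d].view(Wh * Ww, Wh * Ww, -1)

assert torch.equal(bias_old, bias_new)  # same bias, one fewer tensor op per forward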