
Commit f69a9bc

jenkspt and jongwook authored
Remove inefficient computation from AttentionPool2d Module (#271)
* fix inefficient attention computation

* remove erroneous formatting

* simplified flatten

Co-authored-by: Jong Wook Kim <[email protected]>
1 parent 4d120f3 commit f69a9bc

File tree

1 file changed (+3, -4 lines)


clip/model.py

Lines changed: 3 additions & 4 deletions
@@ -66,11 +66,11 @@ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim:
         self.num_heads = num_heads
 
     def forward(self, x):
-        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
         x, _ = F.multi_head_attention_forward(
-            query=x, key=x, value=x,
+            query=x[:1], key=x, value=x,
             embed_dim_to_check=x.shape[-1],
             num_heads=self.num_heads,
             q_proj_weight=self.q_proj.weight,
@@ -88,8 +88,7 @@ def forward(self, x):
             training=self.training,
             need_weights=False
         )
-
-        return x[0]
+        return x.squeeze(0)
 
 
 class ModifiedResNet(nn.Module):
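A minimal sketch (not part of the commit) of why the change is safe: each query token's attention output depends only on the keys and values, so passing just the mean-pooled token as the query produces the same pooled result as attending with all HW+1 tokens and taking index 0, while doing roughly 1/(HW+1) of the query-side work. The shapes (N, C, H, W) and the use of nn.MultiheadAttention below are illustrative assumptions, not the module's actual projection layout.

# Illustrative check, assuming arbitrary shapes and a plain nn.MultiheadAttention
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W, num_heads = 2, 64, 7, 7, 4          # hypothetical sizes
attn = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads)  # (L, N, C) layout
attn.eval()

x = torch.randn(N, C, H, W)

# flatten(start_dim=2) matches the old reshape(N, C, H * W)
assert torch.equal(
    x.flatten(start_dim=2),
    x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]),
)

x = x.flatten(start_dim=2).permute(2, 0, 1)              # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)   # (HW+1)NC

with torch.no_grad():
    old, _ = attn(query=x, key=x, value=x, need_weights=False)      # (HW+1, N, C)
    new, _ = attn(query=x[:1], key=x, value=x, need_weights=False)  # (1, N, C)

# The pooled token is identical either way (up to floating-point noise)
print(torch.allclose(old[0], new.squeeze(0), atol=1e-5))  # True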

0 commit comments
