Skip to content

Commit 47dd60d

Browse files
author
Kent Sommer
committed
Update due to breaking changes in Pytorch 0.2
1 parent 984d04b commit 47dd60d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ def forward(self, X, S1, S2, config):
3939
h = self.h(X)
4040
r = self.r(h)
4141
q = self.q(r)
42-
v, _ = torch.max(q, dim=1)
42+
v, _ = torch.max(q, dim=1, keepdim=True)
4343
for i in range(0, config.k - 1):
4444
q = F.conv2d(torch.cat([r, v], 1),
4545
torch.cat([self.q.weight, self.w], 1),
4646
stride=1,
4747
padding=1)
48-
v, _ = torch.max(q, dim=1)
48+
v, _ = torch.max(q, dim=1, keepdim=True)
4949

5050
q = F.conv2d(torch.cat([r, v], 1),
5151
torch.cat([self.q.weight, self.w], 1),

0 commit comments

Comments
 (0)