Skip to content

Commit 32d2037

Browse files
author
Kent Sommer
committed
Fixed other occurances related to 47dd60d
1 parent 47dd60d commit 32d2037

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def main(config, n_domains=100, max_obs=30,
8383
X_in, S1_in, S2_in = Variable(X_in), Variable(S1_in), Variable(S2_in)
8484
# Forward pass in our neural net
8585
_, predictions = vin(X_in, S1_in, S2_in, config)
86-
_, indices = torch.max(predictions.cpu(),1)
86+
_, indices = torch.max(predictions.cpu(), 1, keepdim=True)
8787
a = indices.data.numpy()[0][0]
8888
# Transform prediction to indices
8989
s = G.map_ind_to_state(pred_traj[j-1, 0], pred_traj[j-1, 1])

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test(net, testloader, config):
7373
# Forward pass
7474
outputs, predictions = net(X, S1, S2, config)
7575
# Select actions with max scores(logits)
76-
_, predicted = torch.max(outputs, dim=1)
76+
_, predicted = torch.max(outputs, dim=1, keepdim=True)
7777
# Unwrap autograd.Variable to Tensor
7878
predicted = predicted.data
7979
# Compute test accuracy

0 commit comments

Comments
 (0)