Skip to content

Commit 24485ed

Browse files
Fixed failing GPU test.
1 parent 9d34978 commit 24485ed

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from mlagents.torch_utils import torch
2+
from mlagents.torch_utils import torch, default_device
33
import numpy as np
44

55
from mlagents.trainers.torch_entities.utils import ModelUtils
@@ -217,7 +217,7 @@ def test_predict_minimum_training():
217217
argmin = argmin.squeeze()
218218
argmin = argmin.detach()
219219
sliced_oh = onehots[:, : num + 1]
220-
inp = torch.cat([inp, sliced_oh], dim=2)
220+
inp = torch.cat([inp, sliced_oh.to(default_device())], dim=2)
221221

222222
embeddings = entity_embedding(inp, inp)
223223
masks = get_zero_entities_mask([inp])

0 commit comments

Comments
 (0)