We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9d34978 commit 24485edCopy full SHA for 24485ed
ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py
@@ -1,5 +1,5 @@
1
import pytest
2
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
3
import numpy as np
4
5
from mlagents.trainers.torch_entities.utils import ModelUtils
@@ -217,7 +217,7 @@ def test_predict_minimum_training():
217
argmin = argmin.squeeze()
218
argmin = argmin.detach()
219
sliced_oh = onehots[:, : num + 1]
220
- inp = torch.cat([inp, sliced_oh], dim=2)
+ inp = torch.cat([inp, sliced_oh.to(default_device())], dim=2)
221
222
embeddings = entity_embedding(inp, inp)
223
masks = get_zero_entities_mask([inp])
0 commit comments