We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2842a86 commit 939d4abCopy full SHA for 939d4ab
efficientnet_pytorch/utils.py
@@ -69,7 +69,7 @@ def drop_connect(inputs, p, training):
69
batch_size = inputs.shape[0]
70
keep_prob = 1 - p
71
random_tensor = keep_prob
72
- random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype) # uniform [0,1)
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
73
binary_tensor = torch.floor(random_tensor)
74
output = inputs / keep_prob * binary_tensor
75
return output
0 commit comments