Skip to content

Commit 939d4ab

Browse files
authored
Fixed CUDA/CPU drop_connect bug
1 parent 2842a86 commit 939d4ab

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

efficientnet_pytorch/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def drop_connect(inputs, p, training):
6969
batch_size = inputs.shape[0]
7070
keep_prob = 1 - p
7171
random_tensor = keep_prob
72-
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype) # uniform [0,1)
72+
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
7373
binary_tensor = torch.floor(random_tensor)
7474
output = inputs / keep_prob * binary_tensor
7575
return output

0 commit comments

Comments
 (0)