Skip to content

Commit f38373f

Browse files
author
Alexander Ororbia
committed
generalized dropout in terms of shape
1 parent 1b7bff8 commit f38373f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

ngclearn/utils/model_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,7 @@ def drop_out(dkey, input, rate=0.0):
567567
Returns:
568568
output as well as binary mask
569569
"""
570-
eps = random.uniform(dkey, (input.shape[0],input.shape[1]),
571-
minval=0.0, maxval=1.0)
570+
eps = random.uniform(dkey, shape=input.shape, minval=0.0, maxval=1.0)
572571
mask = (eps <= (1.0 - rate)).astype(jnp.float32)
573572
mask = mask * (1.0 / (1.0 - rate)) ## apply inverted dropout scheme
574573
output = input * mask

0 commit comments

Comments
 (0)