Skip to content

Commit 24963e5

Browse files
committed
add clipping gradient to model utils
1 parent 1dc4bda commit 24963e5

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

ngclearn/utils/model_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,14 @@ def drop_out(dkey, data, rate=0.0):
675675
output = data * mask
676676
return output, mask
677677

678+
@jit
679+
def clip(x, min_val, max_val):
680+
return jnp.clip(x, min_val, max_val)
681+
682+
@jit
683+
def d_clip(x, min_val, max_val):
684+
return jnp.where((x < min_val) | (x > max_val), 0.0, 1.0)
685+
678686

679687
def scanner(fn):
680688
"""

0 commit comments

Comments
 (0)