Skip to content

Commit eca8a1c

Browse files
committed
add surrogate gradients with class object
1 parent 2d9ab53 commit eca8a1c

File tree

2 files changed

+208
-7
lines changed

2 files changed

+208
-7
lines changed

brainpy/_src/math/surrogate/one_input.py

Lines changed: 173 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
]
3535

3636

37+
class Sigmoid:
  """Callable object computing the sigmoid surrogate spike function.

  Stores the surrogate-gradient parameters once so the object can be
  passed around as a spike function; each call forwards to
  :func:`sigmoid` with the stored configuration.

  Parameters
  ----------
  alpha: float
    Controls the smoothness of the surrogate gradient.
  origin: bool
    Whether to compute the original (smooth) function as the
    feedforward output instead of the Heaviside spike.
  """

  def __init__(self, alpha=4., origin=False):
    self.alpha = alpha
    # Bug fix: the attribute was previously stored as ``self.orgin``
    # (typo) while ``__call__`` read ``self.origin``, so every call
    # raised AttributeError. Store it under the name that is read.
    self.origin = origin

  def __call__(self, x: Union[jax.Array, Array]):
    return sigmoid(x, alpha=self.alpha, origin=self.origin)
3745

3846
@vjp_custom(['x'], dict(alpha=4., origin=False), dict(origin=[True, False]))
3947
def sigmoid(
@@ -105,6 +113,15 @@ def grad(dz):
105113
return z, grad
106114

107115

116+
class PiecewiseQuadratic:
  """Spike function object wrapping :func:`piecewise_quadratic`.

  Holds the surrogate parameters so the same configuration can be
  reused across calls.
  """

  def __init__(self, alpha=1., origin=False):
    # Keep the configuration; ``__call__`` forwards it unchanged.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return piecewise_quadratic(x, alpha=self.alpha, origin=self.origin)
108125
@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
109126
def piecewise_quadratic(
110127
x: Union[jax.Array, Array],
@@ -195,6 +212,15 @@ def grad(dz):
195212
return z, grad
196213

197214

215+
class PiecewiseExp:
  """Reusable spike function delegating to :func:`piecewise_exp`."""

  def __init__(self, alpha=1., origin=False):
    # Configuration captured once at construction time.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return piecewise_exp(x, alpha=self.alpha, origin=self.origin)
198224
@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
199225
def piecewise_exp(
200226
x: Union[jax.Array, Array],
@@ -271,6 +297,15 @@ def grad(dz):
271297
return z, grad
272298

273299

300+
class SoftSign:
  """Spike function object that applies :func:`soft_sign` with a
  fixed parameter set."""

  def __init__(self, alpha=1., origin=False):
    # Stored once, forwarded on every call.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return soft_sign(x, alpha=self.alpha, origin=self.origin)
274309
@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
275310
def soft_sign(
276311
x: Union[jax.Array, Array],
@@ -342,6 +377,15 @@ def grad(dz):
342377
return z, grad
343378

344379

380+
class Arctan:
  """Arctan surrogate spike function as a callable object; see
  :func:`arctan` for the underlying definition."""

  def __init__(self, alpha=1., origin=False):
    # Remember the surrogate configuration.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return arctan(x, alpha=self.alpha, origin=self.origin)
345389
@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
346390
def arctan(
347391
x: Union[jax.Array, Array],
@@ -412,6 +456,15 @@ def grad(dz):
412456
return z, grad
413457

414458

459+
class NonzeroSignLog:
  """Callable wrapper over :func:`nonzero_sign_log` with stored
  surrogate parameters."""

  def __init__(self, alpha=1., origin=False):
    # Parameters fixed at construction.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return nonzero_sign_log(x, alpha=self.alpha, origin=self.origin)
415468
@vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]})
416469
def nonzero_sign_log(
417470
x: Union[jax.Array, Array],
@@ -495,6 +548,15 @@ def grad(dz):
495548
return z, grad
496549

497550

551+
class ERF:
  """Gaussian-error-function surrogate spike function; delegates to
  :func:`erf` with the stored configuration."""

  def __init__(self, alpha=1., origin=False):
    # Configuration lives on the instance.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return erf(x, alpha=self.alpha, origin=self.origin)
498560
@vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]})
499561
def erf(
500562
x: Union[jax.Array, Array],
@@ -569,12 +631,22 @@ def erf(
569631
z = jnp.asarray(x >= 0, dtype=x.dtype)
570632

571633
def grad(dz):
572-
dx = (alpha / math.sqrt(math.pi)) * jnp.exp(-math.pow(alpha, 2) * x * x)
634+
dx = (alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(alpha, 2) * x * x)
573635
return dx * as_jax(dz), None
574636

575637
return z, grad
576638

577639

640+
class PiecewiseLeakyRelu:
  """Piecewise leaky-ReLU surrogate spike function object; forwards to
  :func:`piecewise_leaky_relu`.

  ``c`` is the leak slope outside the window, ``w`` the window
  half-width — see the wrapped function for details.
  """

  def __init__(self, c=0.01, w=1., origin=False):
    # All three knobs are captured once here.
    self.c, self.w, self.origin = c, w, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return piecewise_leaky_relu(x, c=self.c, w=self.w, origin=self.origin)
578650
@vjp_custom(['x'], dict(c=0.01, w=1., origin=False), statics={'origin': [True, False]})
579651
def piecewise_leaky_relu(
580652
x: Union[jax.Array, Array],
@@ -673,6 +745,16 @@ def grad(dz):
673745
return z, grad
674746

675747

748+
class SquarewaveFourierSeries:
  """Square-wave Fourier-series surrogate spike function object;
  forwards to :func:`squarewave_fourier_series`.

  ``n`` is the number of Fourier terms and ``t_period`` the period of
  the square wave — see the wrapped function for details.
  """

  def __init__(self, n=2, t_period=8., origin=False):
    # Captured once; forwarded positionally, matching the wrapped
    # function's (x, n, t_period, origin) order.
    self.n, self.t_period, self.origin = n, t_period, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return squarewave_fourier_series(x, self.n, self.t_period, self.origin)
676758
@vjp_custom(['x'], dict(n=2, t_period=8., origin=False), statics={'origin': [True, False]})
677759
def squarewave_fourier_series(
678760
x: Union[jax.Array, Array],
@@ -732,13 +814,13 @@ def squarewave_fourier_series(
732814
The spiking state.
733815
734816
"""
735-
w = math.pi * 2. / t_period
817+
w = jnp.pi * 2. / t_period
736818
if origin:
737819
ret = jnp.sin(w * x)
738820
for i in range(2, n):
739821
c = (2 * i - 1.)
740822
ret += jnp.sin(c * w * x) / c
741-
z = 0.5 + 2. / math.pi * ret
823+
z = 0.5 + 2. / jnp.pi * ret
742824
else:
743825
z = jnp.asarray(x >= 0, dtype=x.dtype)
744826

@@ -752,6 +834,17 @@ def grad(dz):
752834
return z, grad
753835

754836

837+
class S2NN:
  """S2NN surrogate spike function object; forwards to :func:`s2nn`.

  ``alpha``/``beta`` shape the surrogate, ``epsilon`` guards against
  numerical issues — see the wrapped function for details.
  """

  def __init__(self, alpha=4., beta=1., epsilon=1e-8, origin=False):
    # Four parameters stored once, passed positionally in the wrapped
    # function's (x, alpha, beta, epsilon, origin) order.
    self.alpha, self.beta = alpha, beta
    self.epsilon, self.origin = epsilon, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return s2nn(x, self.alpha, self.beta, self.epsilon, self.origin)
755848
@vjp_custom(['x'],
756849
defaults=dict(alpha=4., beta=1., epsilon=1e-8, origin=False),
757850
statics={'origin': [True, False]})
@@ -844,6 +937,15 @@ def grad(dz):
844937
return z, grad
845938

846939

940+
class QPseudoSpike:
  """q-pseudo-spike surrogate spike function object; forwards to
  :func:`q_pseudo_spike`."""

  def __init__(self, alpha=2., origin=False):
    # Stored once; forwarded positionally on each call.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return q_pseudo_spike(x, self.alpha, self.origin)
847949
@vjp_custom(['x'],
848950
dict(alpha=2., origin=False),
849951
statics={'origin': [True, False]})
@@ -925,6 +1027,16 @@ def grad(dz):
9251027
return z, grad
9261028

9271029

1030+
class LeakyRelu:
  """Leaky-ReLU surrogate spike function object; forwards to
  :func:`leaky_relu` with the stored slopes."""

  def __init__(self, alpha=0.1, beta=1., origin=False):
    # Negative-side slope, positive-side slope, and output mode.
    self.alpha, self.beta, self.origin = alpha, beta, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return leaky_relu(x, self.alpha, self.beta, self.origin)
9281040
@vjp_custom(['x'],
9291041
dict(alpha=0.1, beta=1., origin=False),
9301042
statics={'origin': [True, False]})
@@ -1006,6 +1118,15 @@ def grad(dz):
10061118
return z, grad
10071119

10081120

1121+
class LogTailedRelu:
  """Log-tailed ReLU surrogate spike function object; forwards to
  :func:`log_tailed_relu`."""

  def __init__(self, alpha=0., origin=False):
    # Stored configuration for every subsequent call.
    self.alpha, self.origin = alpha, origin

  def __call__(self, x: Union[jax.Array, Array]):
    return log_tailed_relu(x, self.alpha, self.origin)
10091130
@vjp_custom(['x'],
10101131
dict(alpha=0., origin=False),
10111132
statics={'origin': [True, False]})
@@ -1098,6 +1219,15 @@ def grad(dz):
10981219
return z, grad
10991220

11001221

1222+
class ReluGrad:
  """ReLU-shaped surrogate gradient as a callable object; forwards to
  :func:`relu_grad`. Note this family has no ``origin`` option."""

  def __init__(self, alpha=0.3, width=1.):
    # Gradient scale and support width.
    self.alpha, self.width = alpha, width

  def __call__(self, x: Union[jax.Array, Array]):
    return relu_grad(x, self.alpha, self.width)
11011231
@vjp_custom(['x'], dict(alpha=0.3, width=1.))
11021232
def relu_grad(
11031233
x: Union[jax.Array, Array],
@@ -1163,6 +1293,15 @@ def grad(dz):
11631293
return z, grad
11641294

11651295

1296+
class GaussianGrad:
  """Gaussian surrogate gradient as a callable object; forwards to
  :func:`gaussian_grad`."""

  def __init__(self, sigma=0.5, alpha=0.5):
    # Gaussian width and overall gradient scale.
    self.sigma, self.alpha = sigma, alpha

  def __call__(self, x: Union[jax.Array, Array]):
    return gaussian_grad(x, self.sigma, self.alpha)
11661305
@vjp_custom(['x'], dict(sigma=0.5, alpha=0.5))
11671306
def gaussian_grad(
11681307
x: Union[jax.Array, Array],
@@ -1221,12 +1360,23 @@ def gaussian_grad(
12211360
z = jnp.asarray(x >= 0, dtype=x.dtype)
12221361

12231362
def grad(dz):
1224-
dx = jnp.exp(-(x ** 2) / 2 * math.pow(sigma, 2)) / (math.sqrt(2 * math.pi) * sigma)
1363+
dx = jnp.exp(-(x ** 2) / 2 * jnp.power(sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * sigma)
12251364
return alpha * dx * as_jax(dz), None, None
12261365

12271366
return z, grad
12281367

12291368

1369+
class MultiGaussianGrad:
  """Multi-Gaussian surrogate gradient as a callable object; forwards
  to :func:`multi_gaussian_grad`.

  ``h`` and ``s`` shape the side Gaussians, ``sigma`` the width, and
  ``scale`` the overall gradient magnitude.
  """

  def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
    # Four shape parameters captured once at construction.
    self.h, self.s = h, s
    self.sigma, self.scale = sigma, scale

  def __call__(self, x: Union[jax.Array, Array]):
    return multi_gaussian_grad(x, self.h, self.s, self.sigma, self.scale)
12301380
@vjp_custom(['x'], dict(h=0.15, s=6.0, sigma=0.5, scale=0.5))
12311381
def multi_gaussian_grad(
12321382
x: Union[jax.Array, Array],
@@ -1294,15 +1444,23 @@ def multi_gaussian_grad(
12941444
z = jnp.asarray(x >= 0, dtype=x.dtype)
12951445

12961446
def grad(dz):
1297-
g1 = jnp.exp(-x ** 2 / (2 * math.pow(sigma, 2))) / (math.sqrt(2 * math.pi) * sigma)
1298-
g2 = jnp.exp(-(x - sigma) ** 2 / (2 * math.pow(s * sigma, 2))) / (math.sqrt(2 * math.pi) * s * sigma)
1299-
g3 = jnp.exp(-(x + sigma) ** 2 / (2 * math.pow(s * sigma, 2))) / (math.sqrt(2 * math.pi) * s * sigma)
1447+
g1 = jnp.exp(-x ** 2 / (2 * jnp.power(sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * sigma)
1448+
g2 = jnp.exp(-(x - sigma) ** 2 / (2 * jnp.power(s * sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * s * sigma)
1449+
g3 = jnp.exp(-(x + sigma) ** 2 / (2 * jnp.power(s * sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * s * sigma)
13001450
dx = g1 * (1. + h) - g2 * h - g3 * h
13011451
return scale * dx * as_jax(dz), None, None, None, None
13021452

13031453
return z, grad
13041454

13051455

1456+
class InvSquareGrad:
  """Inverse-square surrogate gradient as a callable object; forwards
  to :func:`inv_square_grad`."""

  def __init__(self, alpha=100.):
    # Sharpness parameter of the surrogate gradient.
    self.alpha = alpha

  def __call__(self, x: Union[jax.Array, Array]):
    return inv_square_grad(x, self.alpha)
13061464
@vjp_custom(['x'], dict(alpha=100.))
13071465
def inv_square_grad(
13081466
x: Union[jax.Array, Array],
@@ -1360,6 +1518,14 @@ def grad(dz):
13601518
return z, grad
13611519

13621520

1521+
class SlayerGrad:
  """SLAYER-style surrogate gradient as a callable object; forwards to
  :func:`slayer_grad`."""

  def __init__(self, alpha=1.):
    # Decay parameter of the surrogate gradient.
    self.alpha = alpha

  def __call__(self, x: Union[jax.Array, Array]):
    return slayer_grad(x, self.alpha)
13631529
@vjp_custom(['x'], dict(alpha=1.))
13641530
def slayer_grad(
13651531
x: Union[jax.Array, Array],

brainpy/math/surrogate.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,58 @@
55
# vjp_custom as vjp_custom
66
# )
77
from brainpy._src.math.surrogate.one_input import (
8+
Sigmoid,
89
sigmoid as sigmoid,
10+
11+
PiecewiseQuadratic,
912
piecewise_quadratic as piecewise_quadratic,
13+
14+
PiecewiseExp,
1015
piecewise_exp as piecewise_exp,
16+
17+
SoftSign,
1118
soft_sign as soft_sign,
19+
20+
Arctan,
1221
arctan as arctan,
22+
23+
NonzeroSignLog,
1324
nonzero_sign_log as nonzero_sign_log,
25+
26+
ERF,
1427
erf as erf,
28+
29+
PiecewiseLeakyRelu,
1530
piecewise_leaky_relu as piecewise_leaky_relu,
31+
32+
SquarewaveFourierSeries,
1633
squarewave_fourier_series as squarewave_fourier_series,
34+
35+
S2NN,
1736
s2nn as s2nn,
37+
38+
QPseudoSpike,
1839
q_pseudo_spike as q_pseudo_spike,
40+
41+
LeakyRelu,
1942
leaky_relu as leaky_relu,
43+
44+
LogTailedRelu,
2045
log_tailed_relu as log_tailed_relu,
46+
47+
ReluGrad,
2148
relu_grad as relu_grad,
49+
50+
GaussianGrad,
2251
gaussian_grad as gaussian_grad,
52+
53+
InvSquareGrad,
2354
inv_square_grad as inv_square_grad,
55+
56+
MultiGaussianGrad,
2457
multi_gaussian_grad as multi_gaussian_grad,
58+
59+
SlayerGrad,
2560
slayer_grad as slayer_grad,
2661
)
2762
from brainpy._src.math.surrogate.two_inputs import (

0 commit comments

Comments
 (0)