3434]
3535
3636
37+ class Sigmoid :
38+ def __init__ (self , alpha = 4. , orgin = False ):
39+ self .alpha = alpha
40+ self .orgin = orgin
41+
42+ def __call__ (self , x : Union [jax .Array , Array ]):
43+ return sigmoid (x , alpha = self .alpha , origin = self .origin )
44+
3745
3846@vjp_custom (['x' ], dict (alpha = 4. , origin = False ), dict (origin = [True , False ]))
3947def sigmoid (
@@ -105,6 +113,15 @@ def grad(dz):
105113 return z , grad
106114
107115
116+ class PiecewiseQuadratic :
117+ def __init__ (self , alpha = 1. , origin = False ):
118+ self .alpha = alpha
119+ self .origin = origin
120+
121+ def __call__ (self , x : Union [jax .Array , Array ]):
122+ return piecewise_quadratic (x , alpha = self .alpha , origin = self .origin )
123+
124+
108125@vjp_custom (['x' ], dict (alpha = 1. , origin = False ), dict (origin = [True , False ]))
109126def piecewise_quadratic (
110127 x : Union [jax .Array , Array ],
@@ -195,6 +212,15 @@ def grad(dz):
195212 return z , grad
196213
197214
215+ class PiecewiseExp :
216+ def __init__ (self , alpha = 1. , origin = False ):
217+ self .alpha = alpha
218+ self .origin = origin
219+
220+ def __call__ (self , x : Union [jax .Array , Array ]):
221+ return piecewise_exp (x , alpha = self .alpha , origin = self .origin )
222+
223+
198224@vjp_custom (['x' ], dict (alpha = 1. , origin = False ), dict (origin = [True , False ]))
199225def piecewise_exp (
200226 x : Union [jax .Array , Array ],
@@ -271,6 +297,15 @@ def grad(dz):
271297 return z , grad
272298
273299
300+ class SoftSign :
301+ def __init__ (self , alpha = 1. , origin = False ):
302+ self .alpha = alpha
303+ self .origin = origin
304+
305+ def __call__ (self , x : Union [jax .Array , Array ]):
306+ return soft_sign (x , alpha = self .alpha , origin = self .origin )
307+
308+
274309@vjp_custom (['x' ], dict (alpha = 1. , origin = False ), dict (origin = [True , False ]))
275310def soft_sign (
276311 x : Union [jax .Array , Array ],
@@ -342,6 +377,15 @@ def grad(dz):
342377 return z , grad
343378
344379
380+ class Arctan :
381+ def __init__ (self , alpha = 1. , origin = False ):
382+ self .alpha = alpha
383+ self .origin = origin
384+
385+ def __call__ (self , x : Union [jax .Array , Array ]):
386+ return arctan (x , alpha = self .alpha , origin = self .origin )
387+
388+
345389@vjp_custom (['x' ], dict (alpha = 1. , origin = False ), dict (origin = [True , False ]))
346390def arctan (
347391 x : Union [jax .Array , Array ],
@@ -412,6 +456,15 @@ def grad(dz):
412456 return z , grad
413457
414458
459+ class NonzeroSignLog :
460+ def __init__ (self , alpha = 1. , origin = False ):
461+ self .alpha = alpha
462+ self .origin = origin
463+
464+ def __call__ (self , x : Union [jax .Array , Array ]):
465+ return nonzero_sign_log (x , alpha = self .alpha , origin = self .origin )
466+
467+
415468@vjp_custom (['x' ], dict (alpha = 1. , origin = False ), statics = {'origin' : [True , False ]})
416469def nonzero_sign_log (
417470 x : Union [jax .Array , Array ],
@@ -495,6 +548,15 @@ def grad(dz):
495548 return z , grad
496549
497550
551+ class ERF :
552+ def __init__ (self , alpha = 1. , origin = False ):
553+ self .alpha = alpha
554+ self .origin = origin
555+
556+ def __call__ (self , x : Union [jax .Array , Array ]):
557+ return erf (x , alpha = self .alpha , origin = self .origin )
558+
559+
498560@vjp_custom (['x' ], dict (alpha = 1. , origin = False ), statics = {'origin' : [True , False ]})
499561def erf (
500562 x : Union [jax .Array , Array ],
@@ -569,12 +631,22 @@ def erf(
569631 z = jnp .asarray (x >= 0 , dtype = x .dtype )
570632
571633 def grad (dz ):
572- dx = (alpha / math .sqrt (math .pi )) * jnp .exp (- math . pow (alpha , 2 ) * x * x )
634+ dx = (alpha / jnp .sqrt (jnp .pi )) * jnp .exp (- jnp . power (alpha , 2 ) * x * x )
573635 return dx * as_jax (dz ), None
574636
575637 return z , grad
576638
577639
640+ class PiecewiseLeakyRelu :
641+ def __init__ (self , c = 0.01 , w = 1. , origin = False ):
642+ self .c = c
643+ self .w = w
644+ self .origin = origin
645+
646+ def __call__ (self , x : Union [jax .Array , Array ]):
647+ return piecewise_leaky_relu (x , c = self .c , w = self .w , origin = self .origin )
648+
649+
578650@vjp_custom (['x' ], dict (c = 0.01 , w = 1. , origin = False ), statics = {'origin' : [True , False ]})
579651def piecewise_leaky_relu (
580652 x : Union [jax .Array , Array ],
@@ -673,6 +745,16 @@ def grad(dz):
673745 return z , grad
674746
675747
748+ class SquarewaveFourierSeries :
749+ def __init__ (self , n = 2 , t_period = 8. , origin = False ):
750+ self .n = n
751+ self .t_period = t_period
752+ self .origin = origin
753+
754+ def __call__ (self , x : Union [jax .Array , Array ]):
755+ return squarewave_fourier_series (x , self .n , self .t_period , self .origin )
756+
757+
676758@vjp_custom (['x' ], dict (n = 2 , t_period = 8. , origin = False ), statics = {'origin' : [True , False ]})
677759def squarewave_fourier_series (
678760 x : Union [jax .Array , Array ],
@@ -732,13 +814,13 @@ def squarewave_fourier_series(
732814 The spiking state.
733815
734816 """
735- w = math .pi * 2. / t_period
817+ w = jnp .pi * 2. / t_period
736818 if origin :
737819 ret = jnp .sin (w * x )
738820 for i in range (2 , n ):
739821 c = (2 * i - 1. )
740822 ret += jnp .sin (c * w * x ) / c
741- z = 0.5 + 2. / math .pi * ret
823+ z = 0.5 + 2. / jnp .pi * ret
742824 else :
743825 z = jnp .asarray (x >= 0 , dtype = x .dtype )
744826
@@ -752,6 +834,17 @@ def grad(dz):
752834 return z , grad
753835
754836
837+ class S2NN :
838+ def __init__ (self , alpha = 4. , beta = 1. , epsilon = 1e-8 , origin = False ):
839+ self .alpha = alpha
840+ self .beta = beta
841+ self .epsilon = epsilon
842+ self .origin = origin
843+
844+ def __call__ (self , x : Union [jax .Array , Array ], ):
845+ return s2nn (x , self .alpha , self .beta , self .epsilon , self .origin )
846+
847+
755848@vjp_custom (['x' ],
756849 defaults = dict (alpha = 4. , beta = 1. , epsilon = 1e-8 , origin = False ),
757850 statics = {'origin' : [True , False ]})
@@ -844,6 +937,15 @@ def grad(dz):
844937 return z , grad
845938
846939
940+ class QPseudoSpike :
941+ def __init__ (self , alpha = 2. , origin = False ):
942+ self .alpha = alpha
943+ self .origin = origin
944+
945+ def __call__ (self , x : Union [jax .Array , Array ]):
946+ return q_pseudo_spike (x , self .alpha , self .origin )
947+
948+
847949@vjp_custom (['x' ],
848950 dict (alpha = 2. , origin = False ),
849951 statics = {'origin' : [True , False ]})
@@ -925,6 +1027,16 @@ def grad(dz):
9251027 return z , grad
9261028
9271029
1030+ class LeakyRelu :
1031+ def __init__ (self , alpha = 0.1 , beta = 1. , origin = False ):
1032+ self .alpha = alpha
1033+ self .beta = beta
1034+ self .origin = origin
1035+
1036+ def __call__ (self , x : Union [jax .Array , Array ]):
1037+ return leaky_relu (x , self .alpha , self .beta , self .origin )
1038+
1039+
9281040@vjp_custom (['x' ],
9291041 dict (alpha = 0.1 , beta = 1. , origin = False ),
9301042 statics = {'origin' : [True , False ]})
@@ -1006,6 +1118,15 @@ def grad(dz):
10061118 return z , grad
10071119
10081120
1121+ class LogTailedRelu :
1122+ def __init__ (self , alpha = 0. , origin = False ):
1123+ self .alpha = alpha
1124+ self .origin = origin
1125+
1126+ def __call__ (self , x : Union [jax .Array , Array ]):
1127+ return log_tailed_relu (x , self .alpha , self .origin )
1128+
1129+
10091130@vjp_custom (['x' ],
10101131 dict (alpha = 0. , origin = False ),
10111132 statics = {'origin' : [True , False ]})
@@ -1098,6 +1219,15 @@ def grad(dz):
10981219 return z , grad
10991220
11001221
1222+ class ReluGrad :
1223+ def __init__ (self , alpha = 0.3 , width = 1. ):
1224+ self .alpha = alpha
1225+ self .width = width
1226+
1227+ def __call__ (self , x : Union [jax .Array , Array ]):
1228+ return relu_grad (x , self .alpha , self .width )
1229+
1230+
11011231@vjp_custom (['x' ], dict (alpha = 0.3 , width = 1. ))
11021232def relu_grad (
11031233 x : Union [jax .Array , Array ],
@@ -1163,6 +1293,15 @@ def grad(dz):
11631293 return z , grad
11641294
11651295
1296+ class GaussianGrad :
1297+ def __init__ (self , sigma = 0.5 , alpha = 0.5 ):
1298+ self .sigma = sigma
1299+ self .alpha = alpha
1300+
1301+ def __call__ (self , x : Union [jax .Array , Array ]):
1302+ return gaussian_grad (x , self .sigma , self .alpha )
1303+
1304+
11661305@vjp_custom (['x' ], dict (sigma = 0.5 , alpha = 0.5 ))
11671306def gaussian_grad (
11681307 x : Union [jax .Array , Array ],
@@ -1221,12 +1360,23 @@ def gaussian_grad(
12211360 z = jnp .asarray (x >= 0 , dtype = x .dtype )
12221361
12231362 def grad (dz ):
1224- dx = jnp .exp (- (x ** 2 ) / 2 * math . pow (sigma , 2 )) / (math .sqrt (2 * math .pi ) * sigma )
1363+ dx = jnp .exp (- (x ** 2 ) / 2 * jnp . power (sigma , 2 )) / (jnp .sqrt (2 * jnp .pi ) * sigma )
12251364 return alpha * dx * as_jax (dz ), None , None
12261365
12271366 return z , grad
12281367
12291368
1369+ class MultiGaussianGrad :
1370+ def __init__ (self , h = 0.15 , s = 6.0 , sigma = 0.5 , scale = 0.5 ):
1371+ self .h = h
1372+ self .s = s
1373+ self .sigma = sigma
1374+ self .scale = scale
1375+
1376+ def __call__ (self , x : Union [jax .Array , Array ]):
1377+ return multi_gaussian_grad (x , self .h , self .s , self .sigma , self .scale )
1378+
1379+
12301380@vjp_custom (['x' ], dict (h = 0.15 , s = 6.0 , sigma = 0.5 , scale = 0.5 ))
12311381def multi_gaussian_grad (
12321382 x : Union [jax .Array , Array ],
@@ -1294,15 +1444,23 @@ def multi_gaussian_grad(
12941444 z = jnp .asarray (x >= 0 , dtype = x .dtype )
12951445
12961446 def grad (dz ):
1297- g1 = jnp .exp (- x ** 2 / (2 * math . pow (sigma , 2 ))) / (math .sqrt (2 * math .pi ) * sigma )
1298- g2 = jnp .exp (- (x - sigma ) ** 2 / (2 * math . pow (s * sigma , 2 ))) / (math .sqrt (2 * math .pi ) * s * sigma )
1299- g3 = jnp .exp (- (x + sigma ) ** 2 / (2 * math . pow (s * sigma , 2 ))) / (math .sqrt (2 * math .pi ) * s * sigma )
1447+ g1 = jnp .exp (- x ** 2 / (2 * jnp . power (sigma , 2 ))) / (jnp .sqrt (2 * jnp .pi ) * sigma )
1448+ g2 = jnp .exp (- (x - sigma ) ** 2 / (2 * jnp . power (s * sigma , 2 ))) / (jnp .sqrt (2 * jnp .pi ) * s * sigma )
1449+ g3 = jnp .exp (- (x + sigma ) ** 2 / (2 * jnp . power (s * sigma , 2 ))) / (jnp .sqrt (2 * jnp .pi ) * s * sigma )
13001450 dx = g1 * (1. + h ) - g2 * h - g3 * h
13011451 return scale * dx * as_jax (dz ), None , None , None , None
13021452
13031453 return z , grad
13041454
13051455
1456+ class InvSquareGrad :
1457+ def __init__ (self , alpha = 100. ):
1458+ self .alpha = alpha
1459+
1460+ def __call__ (self , x : Union [jax .Array , Array ]):
1461+ return inv_square_grad (x , self .alpha )
1462+
1463+
13061464@vjp_custom (['x' ], dict (alpha = 100. ))
13071465def inv_square_grad (
13081466 x : Union [jax .Array , Array ],
@@ -1360,6 +1518,14 @@ def grad(dz):
13601518 return z , grad
13611519
13621520
1521+ class SlayerGrad :
1522+ def __init__ (self , alpha = 1. ):
1523+ self .alpha = alpha
1524+
1525+ def __call__ (self , x : Union [jax .Array , Array ]):
1526+ return slayer_grad (x , self .alpha )
1527+
1528+
13631529@vjp_custom (['x' ], dict (alpha = 1. ))
13641530def slayer_grad (
13651531 x : Union [jax .Array , Array ],
0 commit comments