@@ -128,15 +128,13 @@ class HebbianSynapse(DenseSynapse):
128128 prior: a kernel to drive prior of this synaptic cable's values;
129129 typically a tuple with 1st element as a string calling the name of
130130 prior to use and 2nd element as a floating point number
131- calling the prior parameter lambda (Default: (None , 0.))
132- currently it supports "l1" or "lasso" or "l2" or "ridge" or "l1l2" or "elastic_net".
131+ calling the prior parameter lambda (Default: ('constant' , 0.))
132+ currently it supports "l1"/ "lasso"/"laplacian" or "l2"/ "ridge"/"gaussian" or "l1l2"/ "elastic_net".
133133 usage guide:
134134 prior = ('l1', 0.01) or prior = ('lasso', lmbda)
135135 prior = ('l2', 0.01) or prior = ('ridge', lmbda)
136136 prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))
137137
138-
139-
140138 sign_value: multiplicative factor to apply to final synaptic update before
141139 it is applied to synapses; this is useful if gradient descent style
142140 optimization is required (as Hebbian rules typically yield
@@ -165,7 +163,7 @@ class HebbianSynapse(DenseSynapse):
165163 # Define Functions
166164 @deprecate_args (_rebind = False , w_decay = 'prior' )
167165 def __init__ (self , name , shape , eta = 0. , weight_init = None , bias_init = None ,
168- w_bound = 1. , is_nonnegative = False , prior = (None , 0. ), w_decay = 0. , sign_value = 1. ,
166+ w_bound = 1. , is_nonnegative = False , prior = ("constant" , 0. ), w_decay = 0. , sign_value = 1. ,
169167 optim_type = "sgd" , pre_wght = 1. , post_wght = 1. , p_conn = 1. ,
170168 resist_scale = 1. , batch_size = 1 , ** kwargs ):
171169 super ().__init__ (name , shape , weight_init , bias_init , resist_scale ,
@@ -175,6 +173,8 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
175173 prior = ('l2' , w_decay )
176174
177175 prior_type , prior_lmbda = prior
176+ if prior_type is None :
177+ prior_type = "constant"
178178 ## synaptic plasticity properties and characteristics
179179 self .shape = shape
180180 self .Rscale = resist_scale
0 commit comments