11from jax import random , numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
2+ from ngcsimlib .compilers .process import transition
3+ from ngcsimlib .component import Component
4+ from ngcsimlib .compartment import Compartment
5+
36from .convSynapse import ConvSynapse
7+ from ngclearn .utils .weight_distribution import initialize_params
8+ from ngcsimlib .logger import info
9+ from ngclearn .utils import tensorstats
10+ import ngclearn .utils .weight_distribution as dist
411from ngclearn .components .synapses .convolution .ngcconv import (_conv_same_transpose_padding ,
512 _conv_valid_transpose_padding )
613from ngclearn .components .synapses .convolution .ngcconv import (conv2d , _calc_dX_conv ,
@@ -67,12 +74,14 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
6774 """
6875
6976 # Define Functions
70- def __init__ (self , name , shape , x_shape , A_plus , A_minus , eta = 0. ,
71- pretrace_target = 0. , filter_init = None , stride = 1 , padding = None ,
72- resist_scale = 1. , w_bound = 0. , w_decay = 0. , batch_size = 1 , ** kwargs ):
73- super ().__init__ (name , shape , x_shape = x_shape , filter_init = filter_init ,
74- bias_init = None , resist_scale = resist_scale , stride = stride ,
75- padding = padding , batch_size = batch_size , ** kwargs )
77+ def __init__ (
78+ self , name , shape , x_shape , A_plus , A_minus , eta = 0. , pretrace_target = 0. , filter_init = None , stride = 1 ,
79+ padding = None , resist_scale = 1. , w_bound = 0. , w_decay = 0. , batch_size = 1 , ** kwargs
80+ ):
81+ super ().__init__ (
82+ name , shape , x_shape = x_shape , filter_init = filter_init , bias_init = None , resist_scale = resist_scale ,
83+ stride = stride , padding = padding , batch_size = batch_size , ** kwargs
84+ )
7685
7786 self .eta = eta
7887 self .w_bound = w_bound ## soft weight constraint
@@ -107,8 +116,7 @@ def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
107116 self .x_size , k_size , stride )
108117 ########################################################################
109118
110- def _init (self , batch_size , x_size , shape , stride , padding , pad_args ,
111- weights ):
119+ def _init (self , batch_size , x_size , shape , stride , padding , pad_args , weights ):
112120 k_size , k_size , n_in_chan , n_out_chan = shape
113121 _x = jnp .zeros ((batch_size , x_size , x_size , n_in_chan ))
114122 _d = conv2d (_x , weights .value , stride_size = stride , padding = padding ) * 0
@@ -126,26 +134,28 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
126134 self .x_delta_shape = (dx , dy )
127135
128136 @staticmethod
129- def _compute_update (pretrace_target , Aplus , Aminus , stride , pad_args ,
130- delta_shape , preSpike , preTrace , postSpike , postTrace ):
137+ def _compute_update (
138+ pretrace_target , Aplus , Aminus , stride , pad_args , delta_shape , preSpike , preTrace , postSpike , postTrace
139+ ):
131140 ## Compute long-term potentiation to filters
132- dW_ltp = calc_dK_conv (preTrace - pretrace_target , postSpike * Aplus ,
133- delta_shape = delta_shape , stride_size = stride ,
134- padding = pad_args )
141+ dW_ltp = calc_dK_conv (
142+ preTrace - pretrace_target , postSpike * Aplus , delta_shape = delta_shape , stride_size = stride , padding = pad_args
143+ )
135144 ## Compute long-term depression to filters
136- dW_ltd = - calc_dK_conv (preSpike , postTrace * Aminus ,
137- delta_shape = delta_shape , stride_size = stride ,
138- padding = pad_args )
145+ dW_ltd = - calc_dK_conv (
146+ preSpike , postTrace * Aminus , delta_shape = delta_shape , stride_size = stride , padding = pad_args
147+ )
139148 dWeights = (dW_ltp + dW_ltd )
140149 return dWeights
141150
151+ @transition (output_compartments = ["weights" , "dWeights" ])
142152 @staticmethod
143- def _evolve (pretrace_target , Aplus , Aminus , w_decay , w_bound ,
144- stride , pad_args , delta_shape , preSpike , preTrace , postSpike ,
145- postTrace , weights , eta ):
153+ def evolve (
154+ pretrace_target , Aplus , Aminus , w_decay , w_bound , stride , pad_args , delta_shape , preSpike , preTrace ,
155+ postSpike , postTrace , weights , eta
156+ ):
146157 dWeights = TraceSTDPConvSynapse ._compute_update (
147- pretrace_target , Aplus , Aminus , stride , pad_args , delta_shape ,
148- preSpike , preTrace , postSpike , postTrace
158+ pretrace_target , Aplus , Aminus , stride , pad_args , delta_shape , preSpike , preTrace , postSpike , postTrace
149159 )
150160 if w_decay > 0. : ## apply synaptic decay
151161 weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
@@ -157,14 +167,11 @@ def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
157167 weights = jnp .clip (weights , eps , w_bound - eps )
158168 return weights , dWeights
159169
160- @resolver (_evolve )
161- def evolve (self , weights , dWeights ):
162- self .weights .set (weights )
163- self .dWeights .set (dWeights )
164-
170+ @transition (output_compartments = ["dInputs" ])
165171 @staticmethod
166- def _backtransmit (x_size , shape , stride , padding , x_delta_shape , antiPad ,
167- postSpike , weights ): ## action-backpropagating routine
172+ def backtransmit (
173+ x_size , shape , stride , padding , x_delta_shape , antiPad , postSpike , weights
174+ ): ## action-backpropagating routine
168175 ## calc dInputs - adjustment w.r.t. input signal
169176 k_size , k_size , n_in_chan , n_out_chan = shape
170177 # antiPad = None
@@ -174,16 +181,12 @@ def _backtransmit(x_size, shape, stride, padding, x_delta_shape, antiPad,
174181 # elif padding == "VALID":
175182 # antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
176183 # k_size, stride)
177- dInputs = calc_dX_conv (weights , postSpike , delta_shape = x_delta_shape ,
178- stride_size = stride , anti_padding = antiPad )
184+ dInputs = calc_dX_conv (weights , postSpike , delta_shape = x_delta_shape , stride_size = stride , anti_padding = antiPad )
179185 return dInputs
180186
181- @resolver (_backtransmit )
182- def backtransmit (self , dInputs ):
183- self .dInputs .set (dInputs )
184-
187+ @transition (output_compartments = ["inputs" , "outputs" , "preSpike" , "postSpike" , "preTrace" , "postTrace" ])
185188 @staticmethod
186- def _reset (in_shape , out_shape ):
189+ def reset (in_shape , out_shape ):
187190 preVals = jnp .zeros (in_shape )
188191 postVals = jnp .zeros (out_shape )
189192 inputs = preVals
@@ -194,15 +197,6 @@ def _reset(in_shape, out_shape):
194197 postTrace = postVals
195198 return inputs , outputs , preSpike , postSpike , preTrace , postTrace
196199
197- @resolver (_reset )
198- def reset (self , inputs , outputs , preSpike , postSpike , preTrace , postTrace ):
199- self .inputs .set (inputs )
200- self .outputs .set (outputs )
201- self .preSpike .set (preSpike )
202- self .postSpike .set (postSpike )
203- self .preTrace .set (preTrace )
204- self .postTrace .set (postTrace )
205-
206200 @classmethod
207201 def help (cls ): ## component help function
208202 properties = {
0 commit comments