Skip to content

Commit 56a059f

Browse files
author
Alexander Ororbia
committed
refactored conv/deconv-hebb-syn and tests passed
1 parent 55b0219 commit 56a059f

File tree

7 files changed

+197
-244
lines changed

7 files changed

+197
-244
lines changed

ngclearn/components/neurons/spiking/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
## point to standard spiking cell component types
2-
from .sLIFCell import SLIFCell
2+
# from .sLIFCell import SLIFCell
33
from .LIFCell import LIFCell
44
from .IFCell import IFCell
55
from .WTASCell import WTASCell
@@ -9,3 +9,4 @@
99
from .izhikevichCell import IzhikevichCell
1010
from .RAFCell import RAFCell
1111
from .hodgkinHuxleyCell import HodgkinHuxleyCell
12+

ngclearn/components/synapses/__init__.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,32 @@
77
from .exponentialSynapse import ExponentialSynapse
88
from .doubleExpSynapse import DoupleExpSynapse
99
from .alphaSynapse import AlphaSynapse
10-
#
11-
# ## dense synaptic components
10+
11+
## dense synaptic components
1212
# from .hebbian.hebbianSynapse import HebbianSynapse
13-
# from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
14-
# from .hebbian.expSTDPSynapse import ExpSTDPSynapse
15-
# from .hebbian.eventSTDPSynapse import EventSTDPSynapse
16-
# from .hebbian.BCMSynapse import BCMSynapse
17-
#
18-
#
19-
# ## conv/deconv synaptic components
20-
# from .convolution.convSynapse import ConvSynapse
21-
# from .convolution.staticConvSynapse import StaticConvSynapse
22-
# from .convolution.hebbianConvSynapse import HebbianConvSynapse
23-
# from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
24-
# from .convolution.deconvSynapse import DeconvSynapse
25-
# from .convolution.staticDeconvSynapse import StaticDeconvSynapse
26-
# from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
27-
# from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
28-
#
29-
#
30-
# ## modulated synaptic components
31-
# from .modulated.MSTDPETSynapse import MSTDPETSynapse
13+
from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
14+
from .hebbian.expSTDPSynapse import ExpSTDPSynapse
15+
from .hebbian.eventSTDPSynapse import EventSTDPSynapse
16+
from .hebbian.BCMSynapse import BCMSynapse
17+
18+
19+
## conv/deconv synaptic components
20+
from .convolution.convSynapse import ConvSynapse
21+
from .convolution.staticConvSynapse import StaticConvSynapse
22+
from .convolution.hebbianConvSynapse import HebbianConvSynapse
23+
from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
24+
from .convolution.deconvSynapse import DeconvSynapse
25+
from .convolution.staticDeconvSynapse import StaticDeconvSynapse
26+
from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
27+
from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
28+
29+
30+
## modulated synaptic components
31+
from .modulated.MSTDPETSynapse import MSTDPETSynapse
3232
# from .modulated.REINFORCESynapse import REINFORCESynapse
33-
#
34-
# ## patched synaptic components
33+
34+
## patched synaptic components
3535
# from .patched.patchedSynapse import PatchedSynapse
3636
# from .patched.staticPatchedSynapse import StaticPatchedSynapse
3737
# from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse
38-
#
38+

ngclearn/components/synapses/convolution/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from .staticConvSynapse import StaticConvSynapse
33
from .deconvSynapse import DeconvSynapse
44
from .staticDeconvSynapse import StaticDeconvSynapse
5-
#from .hebbianConvSynapse import HebbianConvSynapse
6-
# from .hebbianDeconvSynapse import HebbianDeconvSynapse
5+
from .hebbianConvSynapse import HebbianConvSynapse
6+
from .hebbianDeconvSynapse import HebbianDeconvSynapse
77
from .traceSTDPConvSynapse import TraceSTDPConvSynapse
88
from .traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
9+

ngclearn/components/synapses/convolution/hebbianConvSynapse.py

Lines changed: 65 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
42
from ngcsimlib.compartment import Compartment
5-
6-
from .convSynapse import ConvSynapse
3+
from ngcsimlib.parser import compilable
74
from ngclearn.utils.weight_distribution import initialize_params
8-
from ngcsimlib.logger import info
9-
from ngclearn.utils import tensorstats
105
import ngclearn.utils.weight_distribution as dist
6+
7+
from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse
8+
119
from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
1210
_conv_valid_transpose_padding)
1311
from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
@@ -17,8 +15,7 @@
1715

1816
class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
1917
"""
20-
A synaptic convolutional cable that adjusts its efficacies via a two-factor
21-
Hebbian adjustment rule.
18+
A specialized synaptic convolutional cable that adjusts its efficacies via a two-factor Hebbian adjustment rule.
2219
2320
| --- Synapse Compartments: ---
2421
| inputs - input (takes in external signals)
@@ -88,10 +85,11 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
8885
"""
8986

9087
# Define Functions
91-
def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
92-
stride=1, padding=None, resist_scale=1., w_bound=0.,
93-
is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
94-
batch_size=1, **kwargs):
88+
def __init__(
89+
self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None,
90+
resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
91+
batch_size=1, **kwargs
92+
):
9593
super().__init__(
9694
name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale,
9795
stride=stride, padding=padding, batch_size=batch_size, **kwargs
@@ -107,9 +105,9 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
107105

108106
######################### set up compartments ##########################
109107
## Compartment setup and shape computation
110-
self.dWeights = Compartment(self.weights.value * 0)
108+
self.dWeights = Compartment(self.weights.get() * 0)
111109
self.dInputs = Compartment(jnp.zeros(self.in_shape))
112-
self.dBiases = Compartment(self.biases.value * 0)
110+
self.dBiases = Compartment(self.biases.get() * 0)
113111
self.pre = Compartment(jnp.zeros(self.in_shape))
114112
self.post = Compartment(jnp.zeros(self.out_shape))
115113

@@ -120,103 +118,97 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
120118
self.antiPad = None
121119
k_size, k_size, n_in_chan, n_out_chan = self.shape
122120
if padding == "SAME":
123-
self.antiPad = _conv_same_transpose_padding(self.post.value.shape[1],
121+
self.antiPad = _conv_same_transpose_padding(self.post.get().shape[1],
124122
self.x_size, k_size, stride)
125123
elif padding == "VALID":
126-
self.antiPad = _conv_valid_transpose_padding(self.post.value.shape[1],
124+
self.antiPad = _conv_valid_transpose_padding(self.post.get().shape[1],
127125
self.x_size, k_size, stride)
128126

129127
########################################################################
130128

131129
## set up outer optimization compartments
132130
self.opt_params = Compartment(get_opt_init_fn(optim_type)(
133-
[self.weights.value, self.biases.value]
134-
if bias_init else [self.weights.value]))
131+
[self.weights.get(), self.biases.get()]
132+
if bias_init else [self.weights.get()])
133+
)
135134

136135
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
137136
k_size, k_size, n_in_chan, n_out_chan = shape
138137
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
139-
_d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
138+
_d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0
140139
_dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args)
141140
## get filter update correction
142-
dx = _dK.shape[0] - weights.value.shape[0]
143-
dy = _dK.shape[1] - weights.value.shape[1]
141+
dx = _dK.shape[0] - weights.get().shape[0]
142+
dy = _dK.shape[1] - weights.get().shape[1]
144143
self.delta_shape = (max(dx, 0), max(dy, 0))
145144
## get input update correction
146-
_dx = _calc_dX_conv(weights.value, _d, stride_size=stride,
147-
anti_padding=pad_args)
145+
_dx = _calc_dX_conv(weights.get(), _d, stride_size=stride, anti_padding=pad_args)
148146
dx = (_dx.shape[1] - _x.shape[1])
149147
dy = (_dx.shape[2] - _x.shape[2])
150148
self.x_delta_shape = (dx, dy)
151149

152-
@staticmethod
153-
def _compute_update(
154-
sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
155-
): ## synaptic kernel adjustment calculation co-routine
150+
def _compute_update(self): #sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
151+
## synaptic kernel adjustment calculation co-routine
156152
## compute adjustment to filters
157-
dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape, stride_size=stride, padding=pad_args)
158-
dWeights = dWeights * sign_value
159-
if w_decay > 0.: ## apply synaptic decay
160-
dWeights = dWeights - weights * w_decay
153+
dWeights = calc_dK_conv(
154+
self.pre.get(), self.post.get(), delta_shape=self.delta_shape, stride_size=self.stride, padding=self.pad_args
155+
)
156+
dWeights = dWeights * self.sign_value
157+
if self.w_decay > 0.: ## apply synaptic decay
158+
dWeights = dWeights - self.weights.get() * self.w_decay
161159
## compute adjustment to base-rates (if applicable)
162160
dBiases = 0. # jnp.zeros((1,1))
163-
if bias_init != None:
164-
dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
161+
if self.bias_init != None:
162+
dBiases = jnp.sum(self.post.get(), axis=0, keepdims=True) * self.sign_value
165163
return dWeights, dBiases
166164

167-
@transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
168-
@staticmethod
169-
def evolve(
170-
opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, stride, pad_args, delta_shape, pre, post,
171-
weights, biases, opt_params
172-
):
165+
@compilable
166+
def evolve(self):
173167
## calc dFilters / dBiases - update to filters and biases
174-
dWeights, dBiases = HebbianConvSynapse._compute_update(
175-
sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
176-
)
177-
if bias_init != None:
178-
opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
168+
dWeights, dBiases = self._compute_update()
169+
if self.bias_init is not None:
170+
opt_params, [weights, biases] = self.opt(self.opt_params.get(), [self.weights.get(), self.biases.get()], [dWeights, dBiases])
179171
else: ## ignore dBiases since no biases configured
180-
opt_params, [weights] = opt(opt_params, [weights], [dWeights])
181-
172+
opt_params, [weights] = self.opt(self.opt_params.get(), [self.weights.get()], [dWeights])
173+
biases = None
182174
## apply any enforced filter constraints
183-
if w_bounds > 0.:
184-
if is_nonnegative:
185-
weights = jnp.clip(weights, 0., w_bounds)
175+
if self.w_bounds > 0.:
176+
if self.is_nonnegative:
177+
weights = jnp.clip(weights, 0., self.w_bounds)
186178
else:
187-
weights = jnp.clip(weights, -w_bounds, w_bounds)
188-
return opt_params, weights, biases, dWeights, dBiases
189-
190-
@transition(output_compartments=["dInputs"])
191-
@staticmethod
192-
def backtransmit(
193-
sign_value, x_size, shape, stride, padding, x_delta_shape, antiPad, post, weights
194-
): ## action-backpropagating routine
179+
weights = jnp.clip(weights, -self.w_bounds, self.w_bounds)
180+
181+
self.opt_params.set(opt_params)
182+
self.weights.set(weights)
183+
self.biases.set(biases)
184+
self.dWeights.set(dWeights)
185+
self.dBiases.set(dBiases)
186+
187+
@compilable
188+
def backtransmit(self): ## action-backpropagating co-routine
195189
## calc dInputs - adjustment w.r.t. input signal
196-
k_size, k_size, n_in_chan, n_out_chan = shape
190+
k_size, k_size, n_in_chan, n_out_chan = self.shape
197191
# antiPad = None
198192
# if padding == "SAME":
199193
# antiPad = _conv_same_transpose_padding(post.shape[1], x_size,
200194
# k_size, stride)
201195
# elif padding == "VALID":
202196
# antiPad = _conv_valid_transpose_padding(post.shape[1], x_size,
203197
# k_size, stride)
204-
dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
198+
dInputs = calc_dX_conv(self.weights.get(), self.post.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, anti_padding=self.antiPad)
205199
## flip sign of back-transmitted signal (if applicable)
206-
dInputs = dInputs * sign_value
207-
return dInputs
208-
209-
@transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
210-
@staticmethod
211-
def reset(in_shape, out_shape):
212-
preVals = jnp.zeros(in_shape)
213-
postVals = jnp.zeros(out_shape)
214-
inputs = preVals
215-
outputs = postVals
216-
pre = preVals
217-
post = postVals
218-
dInputs = preVals
219-
return inputs, outputs, pre, post, dInputs
200+
dInputs = dInputs * self.sign_value
201+
self.dInputs.set(dInputs)
202+
203+
@compilable
204+
def reset(self): #in_shape, out_shape):
205+
preVals = jnp.zeros(self.in_shape.get())
206+
postVals = jnp.zeros(self.out_shape.get())
207+
self.inputs.set(preVals)
208+
self.outputs.set(postVals)
209+
self.pre.set(preVals)
210+
self.post.set(postVals)
211+
self.dInputs.set(preVals)
220212

221213
@classmethod
222214
def help(cls): ## component help function

0 commit comments

Comments
 (0)