Skip to content

Commit 5ec052f

Browse files
author
Alexander Ororbia
committed
refactored regression module to be compliant with v3
1 parent a03480a commit 5ec052f

File tree

5 files changed

+85
-190
lines changed

5 files changed

+85
-190
lines changed

ngclearn/modules/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,3 @@
22
from .regression.lasso import Iterative_Lasso
33
from .regression.ridge import Iterative_Ridge
44

5-
6-
7-
8-

ngclearn/modules/regression/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,3 @@
22
from .lasso import Iterative_Lasso
33
from .ridge import Iterative_Ridge
44

5-
6-
7-
8-
9-

ngclearn/modules/regression/elastic_net.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class Iterative_ElasticNet():
1212
"""
1313
A neural circuit implementation of the iterative Elastic Net (L1 and L2) algorithm
14-
using Hebbian learning update rule.
14+
using a Hebbian learning update rule.
1515
1616
The circuit implements sparse regression through Hebbian synapses with Elastic Net regularization.
1717
@@ -21,8 +21,6 @@ class Iterative_ElasticNet():
2121
| dW_reg = (jnp.sign(W) * l1_ratio) + (W * (1-l1_ratio)/2)
2222
| dW/dt = dW + lmbda * dW_reg
2323
24-
25-
2624
| --- Circuit Components: ---
2725
| W - HebbianSynapse for learning regularized dictionary weights
2826
| err - GaussianErrorCell for computing prediction errors
@@ -104,14 +102,6 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
104102
>> self.W.reset)
105103
self.reset = reset
106104

107-
# advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses
108-
# self.err, ## finally, execute error neurons
109-
# compile_key="advance_state")
110-
# evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve")
111-
# reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
112-
# # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
113-
# self.dynamic()
114-
115105
def batch_set(self, batch_size):
116106
self.W.batch_size = batch_size
117107
self.err.batch_size = batch_size
@@ -121,33 +111,6 @@ def clamp(self, y_scaled, X):
121111
self.W.pre.set(X)
122112
self.err.target.set(y_scaled)
123113

124-
# def dynamic(self): ## create dynamic commands for self.circuit
125-
# W, err = self.circuit.get_components("W", "err")
126-
# self.self = W
127-
# self.err = err
128-
#
129-
# @Context.dynamicCommand
130-
# def batch_set(batch_size):
131-
# self.W.batch_size = batch_size
132-
# self.err.batch_size = batch_size
133-
#
134-
# @Context.dynamicCommand
135-
# def clamps(y_scaled, X):
136-
# self.W.inputs.set(X)
137-
# self.W.pre.set(X)
138-
# self.err.target.set(y_scaled)
139-
#
140-
# self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
141-
# self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
142-
# self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")
143-
#
144-
# @scanner
145-
# def _process(compartment_values, args):
146-
# _t, _dt = args
147-
# compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
148-
# return compartment_values, compartment_values[self.W.weights.path]
149-
150-
151114
def thresholding(self, scale=1.):
152115
coef_old = self.coef_
153116
new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.)
@@ -172,18 +135,3 @@ def fit(self, y, X):
172135

173136
return self.coef_, self.err.mu.get(), self.err.L.get()
174137

175-
# self.circuit.reset()
176-
# self.circuit.clamps(y_scaled=y, X=X)
177-
#
178-
# for i in range(self.epochs):
179-
# self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
180-
# self.circuit.evolve(t=self.T, dt=self.dt)
181-
#
182-
# self.coef_ = np.array(self.W.weights.value)
183-
#
184-
# return self.coef_, self.err.mu.value, self.err.L.value
185-
186-
187-
188-
189-
Lines changed: 40 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,19 @@
1-
import jax
2-
import pandas as pd
3-
from jax import random, jit
41
import numpy as np
5-
from scipy.integrate import solve_ivp
6-
import matplotlib.pyplot as plt
7-
from ngcsimlib.utils import Get_Compartment_Batch
8-
from ngclearn.utils.model_utils import normalize_matrix
92
from ngclearn.utils import weight_distribution as dist
10-
from ngclearn import Context, numpy as jnp
11-
from ngclearn.components import (RateCell,
12-
HebbianSynapse,
13-
GaussianErrorCell,
14-
StaticSynapse)
15-
from ngclearn.utils.model_utils import scanner
3+
from ngclearn import numpy as jnp
164

5+
from jax import numpy as jnp, random, jit
6+
from ngclearn import Context, MethodProcess
7+
from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
8+
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
9+
from ngcsimlib.global_state import stateManager
1710

1811
class Iterative_Lasso():
1912
"""
2013
A neural circuit implementation of the iterative Lasso (L1) algorithm
21-
using Hebbian learning update rule.
14+
using a Hebbian learning update rule.
2215
23-
The circuit implements sparse coding through Hebbian synapses with L1 regularization.
16+
The circuit implements sparse coding-like regression through Hebbian synapses with L1 regularization.
2417
2518
The specific differential equation that characterizes this model is adding lmbda * sign(W)
2619
to the dW (the gradient of loss/energy function):
@@ -89,43 +82,32 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
8982
self.W.batch_size = batch_size
9083
self.err.batch_size = batch_size
9184
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
92-
self.err.mu << self.W.outputs
93-
self.W.post << self.err.dmu
85+
self.W.outputs >> self.err.mu
86+
self.err.dmu >> self.W.post
9487
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
95-
advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses
96-
self.err, ## finally, execute error neurons
97-
compile_key="advance_state")
98-
evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve")
99-
reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
100-
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
101-
self.dynamic()
102-
103-
def dynamic(self): ## create dynamic commands for self.circuit
104-
W, err = self.circuit.get_components("W", "err")
105-
self.self = W
106-
self.err = err
107-
108-
@Context.dynamicCommand
109-
def batch_set(batch_size):
110-
self.W.batch_size = batch_size
111-
self.err.batch_size = batch_size
112-
113-
@Context.dynamicCommand
114-
def clamps(y_scaled, X):
115-
self.W.inputs.set(X)
116-
self.W.pre.set(X)
117-
self.err.target.set(y_scaled)
118-
119-
self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
120-
self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
121-
self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")
122-
123-
@scanner
124-
def _process(compartment_values, args):
125-
_t, _dt = args
126-
compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
127-
return compartment_values, compartment_values[self.W.weights.path]
128-
88+
89+
advance = (MethodProcess(name="advance_state")
90+
>> self.W.advance_state
91+
>> self.err.advance_state)
92+
self.advance = advance
93+
94+
evolve = (MethodProcess(name="evolve")
95+
>> self.W.evolve)
96+
self.evolve = evolve
97+
98+
reset = (MethodProcess(name="reset")
99+
>> self.err.reset
100+
>> self.W.reset)
101+
self.reset = reset
102+
103+
def batch_set(self, batch_size):
104+
self.W.batch_size = batch_size
105+
self.err.batch_size = batch_size
106+
107+
def clamp(self, y_scaled, X):
108+
self.W.inputs.set(X)
109+
self.W.pre.set(X)
110+
self.err.target.set(y_scaled)
129111

130112
def thresholding(self, scale=2):
131113
coef_old = self.coef_
@@ -136,23 +118,16 @@ def thresholding(self, scale=2):
136118

137119
return self.coef_, coef_old
138120

139-
140121
def fit(self, y, X):
141-
142-
self.circuit.reset()
143-
self.circuit.clamps(y_scaled=y, X=X)
122+
self.reset.run()
123+
self.clamp(y_scaled=y, X=X)
144124

145125
for i in range(self.epochs):
146-
self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
147-
self.circuit.evolve(t=self.T, dt=self.dt)
148-
149-
self.coef_ = np.array(self.W.weights.value)
150-
151-
return self.coef_, self.err.mu.value, self.err.L.value
152-
153-
154-
155-
126+
inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt))
127+
stateManager.state, outputs = self.advance.scan(inputs)
128+
self.evolve.run(t=self.T, dt=self.dt)
156129

130+
self.coef_ = np.array(self.W.weights.get())
157131

132+
return self.coef_, self.err.mu.get(), self.err.L.get()
158133

Lines changed: 44 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
from jax import random, jit
21
import numpy as np
32
from ngclearn.utils import weight_distribution as dist
4-
from ngclearn import Context, numpy as jnp
5-
from ngclearn.components import (RateCell,
6-
HebbianSynapse,
7-
GaussianErrorCell,
8-
StaticSynapse)
9-
from ngclearn.utils.model_utils import scanner
10-
3+
from ngclearn import numpy as jnp
114

5+
from jax import numpy as jnp, random, jit
6+
from ngclearn import Context, MethodProcess
7+
from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
8+
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
9+
from ngcsimlib.global_state import stateManager
1210

1311
class Iterative_Ridge():
1412
"""
1513
A neural circuit implementation of the iterative Ridge (L2) algorithm
16-
using Hebbian learning update rule.
14+
using a Hebbian learning update rule.
1715
18-
The circuit implements sparse regression through Hebbian synapses with L2 regularization.
16+
This circuit implements sparse regression through Hebbian synapses with L2 regularization.
1917
2018
The specific differential equation that characterizes this model is adding lmbda * W
2119
to the dW (the gradient of loss/energy function):
@@ -75,54 +73,43 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
7573
feature_dim = dict_dim
7674

7775
with Context(self.name) as self.circuit:
78-
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
79-
sign_value=-1, weight_init=dist.constant(weight_fill),
80-
prior=('ridge', ridge_lmbda), w_bound=0.,
81-
optim_type=optim_type, key=subkeys[0])
76+
self.W = HebbianSynapse(
77+
"W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
78+
weight_init=dist.constant(weight_fill), prior=('ridge', ridge_lmbda), w_bound=0.,
79+
optim_type=optim_type, key=subkeys[0]
80+
)
8281
self.err = GaussianErrorCell("err", n_units=sys_dim)
8382

8483
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8584
self.W.batch_size = batch_size
8685
self.err.batch_size = batch_size
8786
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
88-
self.err.mu << self.W.outputs
89-
self.W.post << self.err.dmu
87+
self.W.outputs >> self.err.mu
88+
self.err.dmu >> self.W.post
9089
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
91-
advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses
92-
self.err, ## finally, execute error neurons
93-
compile_key="advance_state")
94-
evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve")
95-
reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
96-
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
97-
self.dynamic()
98-
99-
def dynamic(self): ## create dynamic commands for self.circuit
100-
W, err = self.circuit.get_components("W", "err")
101-
self.self = W
102-
self.err = err
103-
104-
@Context.dynamicCommand
105-
def batch_set(batch_size):
106-
self.W.batch_size = batch_size
107-
self.err.batch_size = batch_size
108-
109-
@Context.dynamicCommand
110-
def clamps(y_scaled, X):
111-
self.W.inputs.set(X)
112-
self.W.pre.set(X)
113-
self.err.target.set(y_scaled)
114-
115-
self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
116-
self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
117-
self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")
118-
119-
120-
@scanner
121-
def _process(compartment_values, args):
122-
_t, _dt = args
123-
compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
124-
return compartment_values, compartment_values[self.W.weights.path]
12590

91+
advance = (MethodProcess(name="advance_state")
92+
>> self.W.advance_state
93+
>> self.err.advance_state)
94+
self.advance = advance
95+
96+
evolve = (MethodProcess(name="evolve")
97+
>> self.W.evolve)
98+
self.evolve = evolve
99+
100+
reset = (MethodProcess(name="reset")
101+
>> self.err.reset
102+
>> self.W.reset)
103+
self.reset = reset
104+
105+
def batch_set(self, batch_size):
106+
self.W.batch_size = batch_size
107+
self.err.batch_size = batch_size
108+
109+
def clamp(self, y_scaled, X):
110+
self.W.inputs.set(X)
111+
self.W.pre.set(X)
112+
self.err.target.set(y_scaled)
126113

127114
def thresholding(self, scale=2):
128115
coef_old = self.coef_ #self.W.weights.value
@@ -135,21 +122,15 @@ def thresholding(self, scale=2):
135122

136123

137124
def fit(self, y, X):
138-
self.circuit.reset()
139-
self.circuit.clamps(y_scaled=y, X=X)
125+
self.reset.run()
126+
self.clamp(y_scaled=y, X=X)
140127

141128
for i in range(self.epochs):
142-
self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
143-
self.circuit.evolve(t=self.T, dt=self.dt)
144-
145-
self.coef_ = np.array(self.W.weights.value)
146-
147-
return self.coef_, self.err.mu.value, self.err.L.value
148-
149-
150-
151-
152-
129+
inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt))
130+
stateManager.state, outputs = self.advance.scan(inputs)
131+
self.evolve.run(t=self.T, dt=self.dt)
153132

133+
self.coef_ = np.array(self.W.weights.get())
154134

135+
return self.coef_, self.err.mu.get(), self.err.L.get()
155136

0 commit comments

Comments
 (0)