Commit 8a5bc68

Author: Alexander Ororbia (committed)
Commit message: fixed bernoulli error cell
1 parent: 3804426


ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 33 additions & 16 deletions
@@ -2,6 +2,7 @@
 from ngclearn.components.jaxComponent import JaxComponent
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
+from ngclearn.utils.model_utils import sigmoid, d_sigmoid
 
 class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -10,13 +11,13 @@ class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     Bernoulli distribution.
 
     | --- Cell Input Compartments: ---
-    | p - predicted probability of positive trial (takes in external signals)
+    | p - predicted probability (or logits) of positive trial (takes in external signals)
     | target - desired/goal value (takes in external signals)
     | modulator - modulation signal (takes in optional external signals)
     | mask - binary/gating mask to apply to error neuron calculations
     | --- Cell Output Compartments: ---
     | L - local loss function embodied by this cell
-    | dp - derivative of L w.r.t. p
+    | dp - derivative of L w.r.t. p (or logits, if p = sigmoid(logits))
     | dtarget - derivative of L w.r.t. target
 
     Args:
@@ -26,8 +27,11 @@ class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
 
         batch_size: batch size dimension of this cell (Default: 1)
 
+        input_logits: if True, treats compartment `p` as logits and will apply a sigmoidal
+            link, i.e., _p = sigmoid(p), to obtain the param p for Bern(X=1; p)
+
     """
-    def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
+    def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None, **kwargs):
         super().__init__(name, **kwargs)
 
         ## Layer Size Setup
@@ -39,36 +43,49 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
         self.shape = shape
         self.n_units = n_units
         self.batch_size = batch_size
+        self.input_logits = input_logits
 
         ## Convolution shape setup
         self.width = self.height = n_units
 
         ## Compartment setup
         restVals = jnp.zeros(_shape)
         self.L = Compartment(0., display_name="Bernoulli Log likelihood", units="nats") # loss compartment
-        self.p = Compartment(restVals, display_name="Bernoulli prob for B(X=1; p)") # pos trial prob name. input wire
+        self.p = Compartment(restVals, display_name="Bernoulli param (prob or logit) for B(X=1; p)") # pos trial prob name. input wire
         self.dp = Compartment(restVals) # derivative of positive trial prob
         self.target = Compartment(restVals, display_name="Bernoulli data/target variable") # target. input wire
         self.dtarget = Compartment(restVals) # derivative target
         self.modulator = Compartment(restVals + 1.0) # to be set/consumed
         self.mask = Compartment(restVals + 1.0)
 
     @staticmethod
-    def _advance_state(dt, p, target, modulator, mask): ## compute Bernoulli error cell output
+    def _advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output
         # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit
         # behavior of the local cost functional
-        eps = 0.001
-        _p = jnp.clip(p, eps, 1. - eps) ## to prevent division by 0 later on
+        eps = 0.0001
+        _p = p
+        if input_logits: ## convert from "logits" to probs via sigmoidal link function
+            _p = sigmoid(p)
+        _p = jnp.clip(_p, eps, 1. - eps) ## post-process to prevent div by 0
         x = target
-        sum_x = jnp.sum(x) ## Sum^N_{n=1} x_n (n is n-th datapoint)
-        sum_1mx = jnp.sum(1. - x) ## Sum^N_{n=1} (1 - x_n)
-        log_p = jnp.log(_p) ## log(p)
-        log_1mp = jnp.log(1. - _p) ## log(1 - p)
-        L = log_p * sum_x + log_1mp * sum_1mx ## Bern LL
-        dL_dp = sum_x/log_p - sum_1mx/log_1mp ## d(Bern LL)/dp
-        dL_dx = log_p - log_1mp ## d(Bern LL)/dx
-
-        dp = dL_dp * modulator * mask ## not sure how mask will apply to a full covariance...
+        #sum_x = jnp.sum(x) ## Sum^N_{n=1} x_n (n is n-th datapoint)
+        #sum_1mx = jnp.sum(1. - x) ## Sum^N_{n=1} (1 - x_n)
+
+        one_min_p = 1. - _p
+        one_min_x = 1. - x
+        log_p = jnp.log(_p) ## ln(p)
+        log_one_min_p = jnp.log(one_min_p) ## ln(1 - p)
+        L = jnp.sum(log_p * x + log_one_min_p * one_min_x) ## Bern LL
+        if input_logits:
+            dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p)
+        else:
+            dL_dp = x/(_p) - one_min_x/one_min_p ## d(Bern LL)/dp
+        dL_dx = log_p - log_one_min_p ## d(Bern LL)/dx
+        dp = dL_dp #* d_sigmoid(p)
+        if input_logits:
+            dp = dp * d_sigmoid(p)
+
+        dp = dL_dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
         dtarget = dL_dx * modulator * mask
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
         return dp, dtarget, jnp.squeeze(L), mask
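
For reference, the derivative expressions introduced by this commit can be checked numerically. The short sketch below is not part of the commit; it re-implements the Bernoulli log-likelihood in plain JAX (the helper names bern_ll and bern_ll_logits are made up for this check) and compares jax.grad against the analytic forms used in _advance_state: x/p - (1 - x)/(1 - p) for the probability parameterization, x - sigmoid(z) for the logit parameterization (the sigmoid derivative is already folded into that expression), and log(p) - log(1 - p) for the derivative with respect to the target.

# Standalone sanity check of the gradients above; plain JAX, not ngclearn code.
import jax
from jax import numpy as jnp

def bern_ll(p, x):
    ## Bernoulli log-likelihood, probability parameterization
    return jnp.sum(x * jnp.log(p) + (1. - x) * jnp.log(1. - p))

def bern_ll_logits(z, x):
    ## same likelihood, parameterized by logits z with p = sigmoid(z)
    return bern_ll(jax.nn.sigmoid(z), x)

key = jax.random.PRNGKey(0)
kx, kp = jax.random.split(key)
x = jax.random.bernoulli(kx, 0.5, (4,)).astype(jnp.float32)  ## toy binary targets
p = jnp.clip(jax.random.uniform(kp, (4,)), 0.01, 0.99)       ## toy probabilities
z = jnp.log(p) - jnp.log(1. - p)                              ## logit(p)

## dL/dp should equal x/p - (1 - x)/(1 - p)
print(jnp.allclose(jax.grad(bern_ll)(p, x), x / p - (1. - x) / (1. - p), atol=1e-5))
## dL/dz should equal x - sigmoid(z), with no extra d_sigmoid factor
print(jnp.allclose(jax.grad(bern_ll_logits)(z, x), x - jax.nn.sigmoid(z), atol=1e-5))
## dL/dx should equal log(p) - log(1 - p)
print(jnp.allclose(jax.grad(bern_ll, argnums=1)(p, x), jnp.log(p) - jnp.log(1. - p), atol=1e-5))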

0 commit comments
