@@ -2,6 +2,7 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
from ngclearn.utils import tensorstats
+from ngclearn.utils.model_utils import sigmoid

class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
    """
@@ -10,13 +11,13 @@ class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
    Bernoulli distribution.

    | --- Cell Input Compartments: ---
-    | p - predicted probability of a positive trial (takes in external signals)
+    | p - predicted probability (or logits) of a positive trial (takes in external signals)
    | target - desired/goal value (takes in external signals)
    | modulator - modulation signal (takes in optional external signals)
    | mask - binary/gating mask to apply to error neuron calculations
    | --- Cell Output Compartments: ---
    | L - local loss function embodied by this cell
-    | dp - derivative of L w.r.t. p
+    | dp - derivative of L w.r.t. p (or w.r.t. the logits, when p = sigmoid(logits))
    | dtarget - derivative of L w.r.t. target

    Args:
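Reviewer note (summarizing the math this diff implements): the cell's local loss is the Bernoulli log likelihood

$$\mathcal{L}(\mathbf{p}, \mathbf{x}) = \textstyle\sum_i \big[ x_i \log p_i + (1 - x_i) \log(1 - p_i) \big],$$

so $\partial\mathcal{L}/\partial p_i = x_i / p_i - (1 - x_i)/(1 - p_i)$ and $\partial\mathcal{L}/\partial x_i = \log p_i - \log(1 - p_i)$. When `p` carries logits $z_i$ with $p_i = \sigma(z_i)$, the chain rule collapses to $\partial\mathcal{L}/\partial z_i = x_i - \sigma(z_i)$.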
@@ -26,8 +27,11 @@ class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell

        batch_size: batch size dimension of this cell (Default: 1)

+        input_logits: if True, compartment `p` is treated as logits and a sigmoidal
+            link, i.e., _p = sigmoid(p), is applied to obtain the parameter p of Bern(X=1; p)
+
    """
-    def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
+    def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None, **kwargs):
        super().__init__(name, **kwargs)

        ## Layer Size Setup
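Reviewer note: a minimal construction sketch for the new flag (illustration only, not part of the diff; the `Context` workflow and import paths are assumed from ngclearn's usual conventions):

```python
## sketch only: assumes ngclearn's standard ngcsimlib Context workflow and that
## BernoulliErrorCell is exported from ngclearn.components (path assumed)
from ngcsimlib.context import Context
from ngclearn.components import BernoulliErrorCell

with Context("demo") as ctx:
    ## with input_logits=True, the `p` compartment receives raw logits, which the
    ## cell squashes through sigmoid internally before computing the Bernoulli terms
    err = BernoulliErrorCell("err0", n_units=10, batch_size=1, input_logits=True)
```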
@@ -39,36 +43,49 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
        self.shape = shape
        self.n_units = n_units
        self.batch_size = batch_size
+        self.input_logits = input_logits

        ## Convolution shape setup
        self.width = self.height = n_units

        ## Compartment setup
        restVals = jnp.zeros(_shape)
        self.L = Compartment(0., display_name="Bernoulli log likelihood", units="nats") # loss compartment
-        self.p = Compartment(restVals, display_name="Bernoulli prob for B(X=1; p)") # positive trial prob; input wire
+        self.p = Compartment(restVals, display_name="Bernoulli param (prob or logit) for B(X=1; p)") # positive trial prob; input wire
        self.dp = Compartment(restVals) # derivative of positive trial prob
        self.target = Compartment(restVals, display_name="Bernoulli data/target variable") # target; input wire
        self.dtarget = Compartment(restVals) # derivative of target
        self.modulator = Compartment(restVals + 1.0) # to be set/consumed
        self.mask = Compartment(restVals + 1.0)

    @staticmethod
-    def _advance_state(dt, p, target, modulator, mask): ## compute Bernoulli error cell output
+    def _advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output
        # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates
        # the error unit behavior of the local cost functional.
-        eps = 0.001
-        _p = jnp.clip(p, eps, 1. - eps) ## to prevent division by 0 later on
+        eps = 0.0001
+        _p = p
+        if input_logits: ## convert from "logits" to probabilities via a sigmoidal link function
+            _p = sigmoid(p)
+        _p = jnp.clip(_p, eps, 1. - eps) ## post-process to prevent division by 0
        x = target
-        sum_x = jnp.sum(x) ## Sum^N_{n=1} x_n (n indexes the n-th datapoint)
-        sum_1mx = jnp.sum(1. - x) ## Sum^N_{n=1} (1 - x_n)
-        log_p = jnp.log(_p) ## log(p)
-        log_1mp = jnp.log(1. - _p) ## log(1 - p)
-        L = log_p * sum_x + log_1mp * sum_1mx ## Bern LL
-        dL_dp = sum_x / log_p - sum_1mx / log_1mp ## d(Bern LL)/dp
-        dL_dx = log_p - log_1mp ## d(Bern LL)/dx
-
-        dp = dL_dp * modulator * mask ## not sure how mask will apply to a full covariance...
+        one_min_p = 1. - _p
+        one_min_x = 1. - x
+        log_p = jnp.log(_p) ## ln(p)
+        log_one_min_p = jnp.log(one_min_p) ## ln(1 - p)
+        L = jnp.sum(log_p * x + log_one_min_p * one_min_x) ## Bernoulli log likelihood
+        if input_logits:
+            ## the sigmoid link's chain rule is already folded into this branch, since
+            ## d(Bern LL)/dz = x - sigmoid(z) for logits z; no extra d_sigmoid factor is needed
+            dL_dp = x - _p ## d(Bern LL)/d(logits), where _p = sigmoid(p)
+        else:
+            dL_dp = x / _p - one_min_x / one_min_p ## d(Bern LL)/dp
+        dL_dx = log_p - log_one_min_p ## d(Bern LL)/dx
+
+        dp = dL_dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
        dtarget = dL_dx * modulator * mask
        mask = mask * 0. + 1. ## "eat" the mask, as it should only apply at time t
        return dp, dtarget, jnp.squeeze(L), mask
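Reviewer note: the logits branch needs no explicit `d_sigmoid` factor because, for $p = \sigma(z)$, the chain rule yields $\partial\mathcal{L}/\partial z = x - \sigma(z)$ directly. A standalone check of this identity against autodiff (not part of the diff; `jax.nn.sigmoid` stands in for ngclearn's `sigmoid`):

```python
## standalone sanity check; jax.nn.sigmoid stands in for
## ngclearn.utils.model_utils.sigmoid
import jax
from jax import numpy as jnp
from jax.nn import sigmoid

def bern_ll(z, x):
    ## Bernoulli log likelihood with a sigmoidal link applied to logits z
    p = sigmoid(z)
    return jnp.sum(x * jnp.log(p) + (1. - x) * jnp.log(1. - p))

z = jnp.asarray([-1.5, 0.2, 3.0])
x = jnp.asarray([0., 1., 1.])
auto_grad = jax.grad(bern_ll)(z, x)   ## autodiff gradient w.r.t. the logits
closed_form = x - sigmoid(z)          ## form used on the cell's logits branch
print(jnp.allclose(auto_grad, closed_form))  ## expect: True
```

The `x - sigmoid(z)` form is also numerically gentler than the probability-space gradient, since it avoids dividing by clipped values of `p` and `1 - p`.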