11from jax import random , numpy as jnp , jit
22from ngclearn .components .jaxComponent import JaxComponent
3- from ngclearn .utils import tensorstats
4- from ngclearn .utils .weight_distribution import initialize_params
3+ from ngclearn .utils .distribution_generator import DistributionGenerator
54from ngcsimlib .logger import info
65
76from ngcsimlib .compartment import Compartment
@@ -58,10 +57,13 @@ def __init__(
5857
5958 if self .weight_init is None :
6059 info (self .name , "is using default weight initializer!" )
61- self .weight_init = {"dist" : "uniform" , "amin" : 0.025 , "amax" : 0.8 }
62- weights = initialize_params (subkeys [0 ], self .weight_init , shape )
60+ # self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
61+ # weights = initialize_params(subkeys[0], self.weight_init, shape)
62+ self .weight_init = DistributionGenerator .uniform (0.025 , 0.8 )
63+ #weights = initialize_params(subkeys[0], self.weight_init, shape)
64+ weights = self .weight_init (shape , subkeys [0 ])
6365
64- if 0. < p_conn < 1. : ## only non-zero and <1 probs allowed
66+ if 0. < p_conn < 1. : ## Modifier/constraint: only non-zero and <1 probs allowed
6567 p_mask = random .bernoulli (subkeys [1 ], p = p_conn , shape = shape )
6668 weights = weights * p_mask ## sparsify matrix
6769
@@ -76,9 +78,10 @@ def __init__(
7678 if self .bias_init is None :
7779 info (self .name , "is using default bias value of zero (no bias "
7880 "kernel provided)!" )
79- self .biases = Compartment (initialize_params (subkeys [2 ], bias_init ,
80- (1 , shape [1 ]))
81- if bias_init else 0.0 )
81+ self .biases = Compartment (self .bias_init ((1 , shape [1 ]), subkeys [2 ]) if bias_init else 0.0 )
82+ # self.biases = Compartment(initialize_params(subkeys[2], bias_init,
83+ # (1, shape[1]))
84+ # if bias_init else 0.0)
8285
8386 @compilable
8487 def advance_state (self ):
0 commit comments