88from ngcsimlib .logger import info
99from ngclearn import compilable #from ngcsimlib.parser import compilable
1010from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
11+ # from ngclearn.utils.weight_distribution import initialize_params
12+
13+
14+ # def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
15+ # sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
16+ # di, dj = sub_shape
17+ # si, sj = sub_stride
18+
19+ # weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
20+ # #weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
21+ # large_weight_init = DistributionGenerator.constant(value=0.)
22+ # weights = large_weight_init(weight_shape, key[2])
23+
24+ # for i in range(n_sub_models):
25+ # start_i = i * di
26+ # end_i = (i + 1) * di + 2 * si
27+ # start_j = i * dj
28+ # end_j = (i + 1) * dj + 2 * sj
29+
30+ # shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
31+
32+ # ## FIXME: this line below might be wonky...
33+ # weights.at[start_i: end_i, start_j: end_j].set( weight_init(shape_, key[2]) )
34+ # # weights[start_i : end_i,
35+ # # start_j : end_j] = initialize_params(key[2], init_kernel=weight_init, shape=shape_, use_numpy=True)
36+ # if si != 0:
37+ # weights.at[:si,:].set(0.) ## FIXME: this setter line might be wonky...
38+ # weights.at[-si:,:].set(0.) ## FIXME: this setter line might be wonky...
39+ # if sj != 0:
40+ # weights.at[:,:sj].set(0.) ## FIXME: this setter line might be wonky...
41+ # weights.at[:, -sj:].set(0.) ## FIXME: this setter line might be wonky...
42+
43+ # return weights
1144
1245def _create_multi_patch_synapses (key , shape , n_sub_models , sub_stride , weight_init ):
1346 sub_shape = (shape [0 ] // n_sub_models , shape [1 ] // n_sub_models )
1447 di , dj = sub_shape
1548 si , sj = sub_stride
1649
1750 weight_shape = ((n_sub_models * di ) + 2 * si , (n_sub_models * dj ) + 2 * sj )
18- #weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
19- large_weight_init = DistributionGenerator .constant (value = 0. )
20- weights = large_weight_init (weight_shape , key [2 ])
51+ # weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
52+ weights = DistributionGenerator .constant (value = 0. )(weight_shape , key [2 ])
2153
2254 for i in range (n_sub_models ):
2355 start_i = i * di
@@ -27,16 +59,19 @@ def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_in
2759
2860 shape_ = (end_i - start_i , end_j - start_j ) # (di + 2 * si, dj + 2 * sj)
2961
30- ## FIXME: this line below might be wonky...
31- weights .at [start_i : end_i , start_j : end_j ].set ( weight_init (shape_ , key [2 ]) )
3262 # weights[start_i : end_i,
33- # start_j : end_j] = initialize_params(key[2], init_kernel=weight_init, shape=shape_, use_numpy=True)
34- if si != 0 :
35- weights .at [:si ,:].set (0. ) ## FIXME: this setter line might be wonky...
36- weights .at [- si :,:].set (0. ) ## FIXME: this setter line might be wonky...
37- if sj != 0 :
38- weights .at [:,:sj ].set (0. ) ## FIXME: this setter line might be wonky...
39- weights .at [:, - sj :].set (0. ) ## FIXME: this setter line might be wonky...
63+ # start_j : end_j] = initialize_params(key[2],
64+ # init_kernel=weight_init,
65+ # shape=shape_,
66+ # use_numpy=True)
67+ weights = weights .at [start_i : end_i ,
68+ start_j : end_j ].set (weight_init (shape_ , key [2 ]))
69+ if si != 0 :
70+ weights = weights .at [:si ,:].set (0. )
71+ weights = weights .at [- si :,:].set (0. )
72+ if sj != 0 :
73+ weights = weights .at [:,:sj ].set (0. )
74+ weights = weights .at [:, - sj :].set (0. )
4075
4176 return weights
4277
0 commit comments