@@ -73,6 +73,9 @@ class BaseStochasticGradient(ArrayStepShared):
73
73
Total size of the training data
74
74
step_size : float
75
75
Step size for the parameter update
76
+ step_size_decay : int
77
+ Step size decay rate. Every `step_size_decay` iterations the step size is
78
+ reduced to half of its previous value
76
79
model : PyMC Model
77
80
Optional model for sampling step. Defaults to None (taken from context)
78
81
random_seed : int
@@ -98,6 +101,7 @@ def __init__(self,
98
101
batch_size = None ,
99
102
total_size = None ,
100
103
step_size = 1.0 ,
104
+ step_size_decay = 100 ,
101
105
model = None ,
102
106
random_seed = None ,
103
107
minibatches = None ,
@@ -129,6 +133,7 @@ def __init__(self,
129
133
self .random = tt_rng (random_seed )
130
134
131
135
self .step_size = step_size
136
+ self .step_size_decay = step_size_decay
132
137
shared = make_shared_replacements (vars , model )
133
138
self .q_size = int (sum (v .dsize for v in self .vars ))
134
139
@@ -237,7 +242,7 @@ def mk_training_fn(self):
237
242
avg_I = self .avg_I
238
243
t = self .t
239
244
updates = self .updates
240
- step_size = self .step_size
245
+ epsilon = self .step_size / pow ( 2.0 , t // self . step_size_decay )
241
246
random = self .random
242
247
inarray = self .inarray
243
248
gt , dlog_prior = self .dlogp_elemwise , self .dlog_prior
@@ -268,11 +273,11 @@ def mk_training_fn(self):
268
273
# where B_ch is the Cholesky decomposition of B
269
274
# i.e. B = dot(B_ch, B_ch^T)
270
275
B_ch = tt .slinalg .cholesky (B )
271
- noise_term = tt .dot ((2. * B_ch )/ tt .sqrt (step_size ), \
276
+ noise_term = tt .dot ((2. * B_ch )/ tt .sqrt (epsilon ), \
272
277
random .normal ((q_size ,), dtype = theano .config .floatX ))
273
278
# 9.
274
279
# Inv. Fisher Cov. Matrix
275
- cov_mat = (gamma * I_t * N ) + ((4. / step_size ) * B )
280
+ cov_mat = (gamma * I_t * N ) + ((4. / epsilon ) * B )
276
281
inv_cov_mat = tt .nlinalg .matrix_inverse (cov_mat )
277
282
# Noise Coefficient
278
283
noise_coeff = (dlog_prior + (N * avg_gt ) + noise_term )
0 commit comments