1
1
using Flux
2
2
using Statistics
3
3
import Zygote
4
+ import Optimisers
4
5
5
6
"""
6
7
PredictiveModel(networks, input_output_map, input_size, output_size)
@@ -198,7 +199,7 @@ function apply_params(model::PredictiveModel, θ)
198
199
end
199
200
200
201
"""
201
- apply_gradient!(model, dCdy, X, optimizer )
202
+ apply_gradient!(model, dCdy, X, rule )
202
203
203
204
Apply a gradient vector to the model parameters.
204
205
@@ -209,17 +210,16 @@ Apply a gradient vector to the model parameters.
209
210
- `model::PredictiveModel`: model to be updated.
210
211
- `dCdy::Vector{<:Real}`: gradient vector.
211
212
- `X::Matrix{<:Real}`: input data.
212
- - `optimizer `: Optimiser to be used .
213
+ - `rule `: Optimisation rule .
213
214
...
214
215
"""
215
216
function apply_gradient! (
216
217
model:: PredictiveModel ,
217
218
dCdy:: Vector{<:Real} ,
218
219
X:: Matrix{<:Real} ,
219
- optimizer ,
220
+ opt_state ,
220
221
)
221
- ps = Flux. params (model. networks)
222
- loss (x, y) = mean (dCdy' model (x))
223
- train_data = [(X' , 0.0 )]
224
- return Flux. train! (loss, ps, train_data, optimizer)
222
+ loss3 (m, X) = mean (dCdy' m (X' ))
223
+ grad = Zygote. gradient (loss3, model, X)[1 ]
224
+ return Optimisers. update! (opt_state, model, grad)
225
225
end
0 commit comments