-
Notifications
You must be signed in to change notification settings - Fork 8
Open
Description
Hi, i was trying to implement the update method for laplaceredux but I am having a problem.
this is the model
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
model::Flux.Chain = nothing
flux_loss = Flux.Losses.mse
optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
:full::(_ in (:full, :diagonal))
backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
fit_prior_nsteps::Int = 100::(_ > 0)
end
this is the fit function that i have written
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
#X = MLJBase.matrix(X) |> permutedims
#y = reshape(y, 1, :)
if Tables.istable(X)
X = Tables.matrix(X)|>permutedims
end
# Reshape y if necessary
y = reshape(y, 1, :)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
opt_state = Flux.setup(m.optimiser, m.model)
loss_history=[]
push!(loss_history, m.flux_loss(m.model(X), y ))
for epoch in 1:(m.epochs)
loss_per_epoch= 0.0
for (X_batch, y_batch) in data_loader
# Forward pass: compute predictions
y_pred = m.model(X_batch)
# Compute loss
loss = m.flux_loss(y_pred, y_batch)
# Compute gradients
grads = gradient(m.model) do model
# Recompute predictions inside gradient context
y_pred = model(X_batch)
m.flux_loss(y_pred, y_batch)
end
# Update parameters using the optimizer and computed gradients
Flux.Optimise.update!(opt_state ,m.model , grads[1])
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
end
push!(loss_history,loss_per_epoch )
# Print loss every 100 epochs if verbosity is 1 or more
if verbosity >= 1 && epoch % 100 == 0
println("Epoch $epoch: Loss: $loss_per_epoch ")
end
end
la = LaplaceRedux.Laplace(
m.model;
likelihood=:regression,
subset_of_weights=m.subset_of_weights,
subnetwork_indices=m.subnetwork_indices,
hessian_structure=m.hessian_structure,
backend=m.backend,
σ=m.σ,
μ₀=m.μ₀,
P₀=m.P₀,
)
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
fitresult = la
report = (loss_history = loss_history,)
cache = (deepcopy(m),opt_state, loss_history)
return fitresult, cache, report
end
and now follows the incomplete update function that i was trying. I have removed the loop part since it's not important.
function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)
println(" running MMI:update")
old_model = old_cache[1]
if Tables.istable(X)
X = Tables.matrix(X)|>permutedims
end
# Reshape y if necessary
y = reshape(y, 1, :)
println(MMI.is_same_except(m, old_model, :epochs))
cache=()
report=()
return old_fitresult, cache, report
end
the issue is that if i try to rerun the model by changing only the number of epochs is_same_except still gives me
false
even though :epochs is listed as exception
using MLJ
flux_model = Chain(
Dense(4, 10, relu),
Dense(10, 10, relu),
Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y)
MLJBase.fit!(mach)
model.epochs=2000
MLJBase.fit!(mach)
so what is the correct way to implement is_same_except? thank you
Metadata
Metadata
Assignees
Labels
No labels