Skip to content

Question on the use of the Update! method and is_same_except() #212

@pasq-cat

Description

@pasq-cat

Hi, i was trying to implement the update method for laplaceredux but I am having a problem.

this is the model

MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
    model::Flux.Chain = nothing
    flux_loss = Flux.Losses.mse
    optimiser = Adam()
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer = 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    fit_prior_nsteps::Int = 100::(_ > 0)
end

this is the fit function that i have written

function MMI.fit(m::LaplaceRegressor, verbosity, X, y)

    #X = MLJBase.matrix(X) |> permutedims
    #y = reshape(y, 1, :)

    if Tables.istable(X)
        X = Tables.matrix(X)|>permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
    opt_state = Flux.setup(m.optimiser, m.model)
    loss_history=[]
    push!(loss_history, m.flux_loss(m.model(X), y ))

    for epoch in 1:(m.epochs)

        loss_per_epoch= 0.0


        for (X_batch, y_batch) in data_loader
            # Forward pass: compute predictions
            y_pred = m.model(X_batch)

            # Compute loss
            loss = m.flux_loss(y_pred, y_batch)

            # Compute gradients 
            grads = gradient(m.model) do model
                # Recompute predictions inside gradient context
                y_pred = model(X_batch)
                m.flux_loss(y_pred, y_batch)
            end
            
            # Update parameters using the optimizer and computed gradients
            Flux.Optimise.update!(opt_state ,m.model , grads[1])

            # Accumulate the loss for this batch
            loss_per_epoch += sum(loss)  # Summing the batch loss
            
        end

        push!(loss_history,loss_per_epoch )

        # Print loss every 100 epochs if verbosity is 1 or more
        if verbosity >= 1 && epoch % 100 == 0
            println("Epoch $epoch: Loss: $loss_per_epoch ")
        end
    end

    la = LaplaceRedux.Laplace(
        m.model;
        likelihood=:regression,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # fit the Laplace model:
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    fitresult = la
    report = (loss_history = loss_history,)
    cache = (deepcopy(m),opt_state, loss_history)
    return fitresult, cache, report
end

and now follows the incomplete update function that i was trying. I have removed the loop part since it's not important.

function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)

println(" running MMI:update")

old_model = old_cache[1]

if Tables.istable(X)
    X = Tables.matrix(X)|>permutedims
end

# Reshape y if necessary
y = reshape(y, 1, :)


println(MMI.is_same_except(m, old_model, :epochs))



cache=()
report=()
return old_fitresult, cache, report
end



the issue is that if i try to rerun the model by changing only the number of epochs is_same_except still gives me

false

even though :epochs is listed as exception

using MLJ
flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) 
MLJBase.fit!(mach)



model.epochs=2000

MLJBase.fit!(mach)

so what is the correct way to implement is_same_except? thank you

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions