
Question on how to correctly implement an interface to a probabilistic classifier #211

@pasq-cat


Hi, I was trying to implement an interface between LaplaceRedux and MLJ, but I am running into trouble with the probabilistic classifier model. In particular, I have not fully understood how to use `UnivariateFinite` correctly.

I imported the following packages:

```julia
using Flux
using Random
using Tables
using LinearAlgebra
using LaplaceRedux
using MLJBase
import MLJFlux                    # needed for the MLJFlux.MLJFluxProbabilistic supertype
import MLJModelInterface as MMI
using Distributions: Normal
```

Then I created the model:

```julia
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
    # the default is `nothing`, so the field type must admit it
    flux_model::Union{Flux.Chain,Nothing} = nothing
    flux_loss = Flux.Losses.logitcrossentropy
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer = 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    #ret_distr::Bool = false::(_ in (true, false))
    fit_prior_nsteps::Int = 100::(_ > 0)
    link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
```

Next, I wrote a `fit` method:

```julia
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
    X = MLJBase.matrix(X) |> permutedims     # features × observations, as Flux expects
    decode = y[1]                            # any element carries the full pool of classes
    nclasses = length(MLJBase.classes(decode))
    # Encode against the pool order 1:nclasses (not `unique`, whose order depends on
    # the data) so that row k always corresponds to the k-th element of classes(decode)
    y_onehot = Flux.onehotbatch(MLJBase.int(y), 1:nclasses)
    data_loader = Flux.DataLoader((X, y_onehot), batchsize=m.batch_size)

    # Train a copy so that `fit` does not mutate the `flux_model` hyperparameter
    model = deepcopy(m.flux_model)
    opt_state = Flux.setup(Adam(), model)

    for epoch in 1:m.epochs
        Flux.train!(model, data_loader, opt_state) do net, x_batch, y_batch
            m.flux_loss(net(x_batch), y_batch)
        end
    end

    la = LaplaceRedux.Laplace(
        model;
        likelihood=:classification,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # fit the Laplace approximation, then tune the prior
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    report = (status="success", message="Model fitted successfully")
    cache = nothing
    return ((la, decode), cache, report)
end
```
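Why the order of the one-hot labels matters: encoding targets with `Flux.onehotbatch(codes, unique(codes))` assigns one-hot rows in the order in which classes first appear in the training data, while the probabilities returned later are labelled with `MLJBase.classes(decode)`, which is in pool order. The minimal sketch below uses only Base Julia and hypothetical integer codes (standing in for `MLJBase.int(y)`) to show how the two orderings can disagree.

```julia
# Hypothetical integer class codes for a 3-class target, standing in for
# MLJBase.int(y); the first observation happens to belong to class 3.
codes = [3, 1, 2, 1, 3]

# Labels taken from the data: order of first appearance.
labels_from_data = unique(codes)          # [3, 1, 2]

# Labels taken from the pool: always 1, 2, ..., number of classes.
nclasses = 3
labels_from_pool = collect(1:nclasses)    # [1, 2, 3]

# One-hot row that receives the 1 for each observation, under each convention
# (this is what onehotbatch computes for each entry of `codes`).
row_data = [findfirst(==(c), labels_from_data) for c in codes]   # [1, 2, 3, 2, 1]
row_pool = [findfirst(==(c), labels_from_pool) for c in codes]   # [3, 1, 2, 1, 3]
```

With `unique`, an observation of class 3 is encoded in row 1, so the network's first output would end up paired with the wrong class name at prediction time; with `1:nclasses`, row `k` is always the `k`-th class of the pool.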

Finally, I wrote the `predict` method:

```julia
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
    la = fitresult
    Xnew = MLJBase.matrix(Xnew) |> permutedims
    predictions = LaplaceRedux.predict(
        la,
        Xnew;
        link_approx=m.link_approx,
        ret_distr=false)
    return [
        MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode, augment=true)
        for prediction in predictions
    ]
end
```
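The comprehension in the return statement iterates over `predictions` directly. Assuming that LaplaceRedux's classification `predict` with `ret_distr=false` returns a probability matrix with one row per class and one column per observation, iterating it yields individual `Float64` entries rather than one probability vector per observation. The Base-only sketch below uses a hypothetical 2×3 matrix to contrast plain iteration, `eachcol`, and `permutedims`.

```julia
# Hypothetical predictive probabilities for a 2-class problem with 3 new
# observations, laid out (by assumption) as classes × observations.
probs = [0.75 0.5 0.125;
         0.25 0.5 0.875]

# Iterating a Matrix visits every scalar entry in column-major order, so a
# comprehension over `probs` receives six Float64 values, not three vectors.
entries = [p for p in probs]        # 2×3 Matrix{Float64}, one scalar per entry

# `eachcol` yields one probability vector per observation instead.
per_obs = [collect(col) for col in eachcol(probs)]
# [[0.75, 0.25], [0.5, 0.5], [0.125, 0.875]]

# Transposing gives an observations × classes layout, one row per observation.
obs_by_class = permutedims(probs)   # 3×2 Matrix{Float64}
```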

When I run `predict`, I get the warning

```
Warning: Ignoring value of pool as the specified support defines one already.
```

followed by an error raised at the last line, the `UnivariateFinite` call.
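A possible fix, sketched under the same layout assumption: build all predictions with a single `UnivariateFinite` call that takes the categorical support from `MLJBase.classes(decode)` and an observations × classes probability matrix, and drop both keywords. `pool` is redundant because the support is already categorical (which is what the warning reports), and `augment=true` is meant for probability arrays that omit the first class, not for a full probability vector per observation. The labels and probabilities below are toy values, not LaplaceRedux output.

```julia
using MLJBase   # provides categorical, classes, UnivariateFinite and pdf

# Hypothetical training target with three classes.
y = categorical(["a", "b", "c", "a"])
decode = y[1]                 # a single element still carries the full pool
support = classes(decode)     # all levels, in pool order: "a", "b", "c"

# Stand-in for the transposed model output: one row per observation,
# one column per class, each row summing to one.
probs = [0.5   0.25 0.25;
         0.125 0.75 0.125]

# A matrix of probabilities yields a vector of UnivariateFinite distributions.
# No `pool` keyword (the support already defines it) and no `augment=true`
# (every class column is supplied).
ŷ = UnivariateFinite(support, probs)

pdf.(ŷ, "a")   # probability of class "a" for each observation
```

Applied to the `predict` method, this would correspond to returning `MLJBase.UnivariateFinite(MLJBase.classes(decode), permutedims(predictions))`, provided `predictions` really is the classes × observations matrix assumed here.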
