Hi, I am trying to implement an interface between LaplaceRedux and MLJ, but I am facing an issue with the probabilistic classifier model. In particular, I have not fully understood how to use UnivariateFinite correctly.
I have imported the packages:
using Flux
using Random
using Tables
using LinearAlgebra
using LaplaceRedux
using MLJBase
using MLJFlux   # needed below for MLJFlux.MLJFluxProbabilistic
import MLJModelInterface as MMI
using Distributions: Normal
created the model:
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
    flux_model::Union{Flux.Chain,Nothing} = nothing   # Union type so the default `nothing` is accepted
    flux_loss = Flux.Losses.logitcrossentropy
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer = 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    #ret_distr::Bool = false::(_ in (true, false))
    fit_prior_nsteps::Int = 100::(_ > 0)
    link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
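For concreteness, I construct the model along these lines (a minimal sketch; the layer sizes and the number of classes are made up purely for illustration):

# Hypothetical setup: 4 input features, 3 classes.
flux_model = Flux.Chain(Flux.Dense(4, 16, Flux.relu), Flux.Dense(16, 3))
model = LaplaceClassifier(flux_model=flux_model, epochs=100)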
I have then written a fit function:
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
    X = MLJBase.matrix(X) |> permutedims
    decode = y[1]                      # keep one categorical element to recover the pool later
    y_plain = MLJBase.int(y) .- 1
    y_onehot = Flux.onehotbatch(y_plain, sort(unique(y_plain)))   # sort so the class order is stable
    data_loader = Flux.DataLoader((X, y_onehot), batchsize=m.batch_size)
    opt_state = Flux.setup(Adam(), m.flux_model)

    for epoch in 1:m.epochs
        Flux.train!(m.flux_model, data_loader, opt_state) do model, Xb, yb
            m.flux_loss(model(Xb), yb)
        end
    end

    la = LaplaceRedux.Laplace(
        m.flux_model;
        likelihood=:classification,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # fit the Laplace approximation and tune the prior:
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    report = (status="success", message="Model fitted successfully")
    cache = nothing
    return ((la, decode), cache, report)
end
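For context, I can call the fit method directly along these lines (a sketch only; the MLJBase.make_blobs data and the model from the sketch above are just for illustration):

# Hypothetical smoke test: 3 clusters, 4 features, matching the shapes above.
X, y = MLJBase.make_blobs(100, 4; centers=3)
fitresult, cache, report = MMI.fit(model, 0, X, y)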
I have also written the predict function:
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
    la = fitresult
    Xnew = MLJBase.matrix(Xnew) |> permutedims
    predictions = LaplaceRedux.predict(
        la,
        Xnew;
        link_approx=m.link_approx,
        ret_distr=false,
    )
    return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode, augment=true) for prediction in predictions]
end
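Going through an MLJ machine, the call looks roughly like this (same hypothetical model and data as above; I have not declared any model metadata yet):

# Sketch: wrap the model in a machine, fit it, then predict.
mach = MLJBase.machine(model, X, y)
MLJBase.fit!(mach)
yhat = MLJBase.predict(mach, X)   # this is the call that produces the warning below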
But when I run predict I get the following warning:

Warning: Ignoring value of pool as the specified support defines one already.

It comes only from the last line, the one that builds the UnivariateFinite.
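For what it's worth, my current guess is that MLJBase.classes(decode) already returns CategoricalValues carrying their own pool, so the pool= (and possibly augment=) keywords are redundant there. Something like the following might be the intended pattern, assuming LaplaceRedux.predict returns a matrix of class probabilities (I am not certain of its exact shape or orientation):

# Untested sketch: pass the probability matrix directly, one row per observation and
# one column per class, with no pool/augment keywords.
probs = permutedims(predictions)   # assuming predictions is classes × observations
return MLJBase.UnivariateFinite(MLJBase.classes(decode), probs)

Is this the right way to use it, or am I misunderstanding how the support and pool interact?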