-
Notifications
You must be signed in to change notification settings - Fork 8
Description
Hi - is there a method to get Shapley values for classification problems? The code I tried is below:
# Load the RandomForestClassifier model type from the DecisionTree package.
RFC = @load RandomForestClassifier pkg="DecisionTree"
# Instantiate the model; no arguments are passed to the constructor.
rfc_model = RFC()
# Bind the model to the data. X and y are defined outside this snippet.
rf_machine = machine(rfc_model, X, y)
# Fit the machine.
MLJ.fit!(rf_machine)
# Prediction wrapper handed to ShapML.shap: calls MLJ.predict on `data` with
# `model` and returns the result as a one-column DataFrame named `y_pred`.
#
# NOTE(review): per the stack trace below, the predictions produced here are
# UnivariateFinite values, and ShapML's call to `mean` on that vector is what
# fails. ShapML presumably needs a numeric column (e.g. a class probability
# obtained with `pdf.(predictions, class)`) - confirm against the ShapML docs.
function predict_function(model, data)
    predictions = MLJ.predict(model, data)
    return DataFrame(y_pred = predictions)
end
# Rows to explain: a copy of the first 50 rows of X.
explain = copy(X[1:50, :])
# Reference data passed to ShapML: a copy of all of X.
reference = copy(X)
sample_size = 60 # Number of Monte Carlo samples
# Request Shapley values from ShapML, passing the fitted machine as the model
# and predict_function to produce predictions. This is the call that raises
# the error shown in the stack trace below.
data_shap = ShapML.shap(explain=explain, reference=reference, model=rf_machine,
predict_function=predict_function, sample_size=sample_size, seed=1)
And I am getting the following error:
ERROR: LoadError: TypeError: in LocationScale, in T, expected T<:Real, got Type{Any}
Stacktrace:
[1] Distributions.LocationScale(μ::Float64, σ::Float64, ρ::UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64}; check_args::Bool)
@ Distributions ~/.julia/packages/Distributions/jEqbk/src/univariate/locationscale.jl:50
[2] Distributions.LocationScale(μ::Float64, σ::Float64, ρ::UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64})
@ Distributions ~/.julia/packages/Distributions/jEqbk/src/univariate/locationscale.jl:47
[3] *(x::Float64, d::UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64})
@ Distributions ~/.julia/packages/Distributions/jEqbk/src/univariate/locationscale.jl:126
[4] /(d::UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64}, x::Int64)
@ Distributions ~/.julia/packages/Distributions/jEqbk/src/univariate/locationscale.jl:129
[5] _mean(f::typeof(identity), A::Vector{UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64}}, dims::Colon)
@ Statistics /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:176
[6] mean(A::Vector{UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64}}; dims::Function)
@ Statistics /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:164
[7] mean(A::Vector{UnivariateFinite{OrderedFactor{2}, Int64, UInt32, Float64}})
@ Statistics /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:164
[8] _predict(; reference::DataFrame, data_predict::DataFrame, model::Machine{MLJDecisionTreeInterface.RandomForestClassifier, true}, predict_function::typeof(predict_function), n_features::Int64, n_target_features::Int64, n_instances_explain::Int64, sample_size::Int64, precision::Nothing, chunk::Bool, reconcile_instance::Bool, explain::DataFrame)
@ ShapML ~/.julia/packages/ShapML/QMund/src/predict.jl:30
[9] shap(; explain::DataFrame, reference::DataFrame, model::Machine{MLJDecisionTreeInterface.RandomForestClassifier, true}, predict_function::Function, target_features::Nothing, sample_size::Int64, parallel::Nothing, seed::Int64, precision::Nothing, chunk::Bool, reconcile_instance::Bool)
@ ShapML ~/.julia/packages/ShapML/QMund/src/ShapML.jl:168
[10] top-level scope
@ Untitled-1:21
in expression starting at Untitled-1:21