Skip to content
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- 'lts'
- '1'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
- uses: julia-actions/cache@v2
env:
cache-name: cache-artifacts
with:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.8.8"
version = "0.8.9"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
6 changes: 5 additions & 1 deletion src/plotrecipes.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
const MAX_AXIS_LABEL_WIDTH = 20

@recipe function f(mach::MLJBase.Machine{<:EitherTunedModel})
rep = report(mach)
measurement = repr(rep.best_history_entry.measure[1])
r = rep.plotting
z = r.measurements
X = r.parameter_values
guides = r.parameter_names
guides = map(r.parameter_names) do name
trim(name, MAX_AXIS_LABEL_WIDTH)
end
scales = r.parameter_scales
n = size(X, 2)
indices = LinearIndices((n, n))'
Expand Down
2 changes: 1 addition & 1 deletion src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
const ERR_SPECIFY_MODEL = ArgumentError(
"You need to specify `model=...`, unless `tuning=Explicit()`. ")
const ERR_SPECIFY_RANGE = ArgumentError(
"You need to specify `range=...`, unless `tuning=Explicit()` and "*
"You need to specify `range=...`, unless `tuning=Explicit()` "*
"and `models=...` is specified instead. ")
const ERR_SPECIFY_RANGE_OR_MODELS = ArgumentError(
"No `model` specified. Either specify an explicit iterator "*
Expand Down
19 changes: 19 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,22 @@ signature(measure) =
else
0
end

"""
    trim(str, N)

Shorten a dotted name such as
`"transformed_target_model_deterministic.model.K"` to roughly `N`
characters by dropping leading dot-separated components and prepending an
ellipsis (`…`). For example, with `N=20` the string above becomes
`"…model.K"`. Used in plotrecipes.jl for axis labels.

If `str` already fits, or consists of a single component (so nothing can
be dropped), it is returned unchanged. Otherwise the result may still
exceed `N` when even the last component alone is too long.
"""
function trim(str, N)
    remaining = length(str)
    remaining <= N && return str
    parts = split(str, ".")
    trimmed = false
    # Drop leading components until what remains fits, with room for the
    # ellipsis (1 character), or only one component is left:
    while length(parts) > 1
        dropped = popfirst!(parts)
        trimmed = true
        remaining -= length(dropped) + 1 # the `1` is for the dot, `.`
        remaining < N && break
    end
    # A single overlong component cannot be shortened; prepending an
    # ellipsis would only make it longer, so return it as-is:
    trimmed || return str
    return "…" * join(parts, ".")
end
79 changes: 79 additions & 0 deletions test/density_estimation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Test
using MLJTuning
using MLJBase
using StatisticalMeasures
using StableRNGs
import MLJModelInterface
import StatisticalMeasures: CategoricalDistributions, Distributions


# We define a density estimator to fit a `UnivariateFinite` distribution to some
# Categorical data, with a Laplace smoothing option, α.

# A density estimator: fits a `UnivariateFinite` distribution to categorical
# data, with Laplace smoothing controlled by the hyperparameter `alpha` (α).
mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic
    alpha::Float64 # Laplace smoothing parameter, α
end
# Keyword constructor with default smoothing α = 1.0 (standard add-one
# smoothing).
UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha)

# Fit a Laplace-smoothed `UnivariateFinite` distribution to the categorical
# target `y`. `X` is ignored (density estimation takes no input features).
# Each class `c` gets probability (count(c) + α)/(N + α*d), where `N` is the
# number of observations and `d` the number of classes. Returns the standard
# MLJ triple `(fitresult, cache, report)`, with the fitted distribution as
# `fitresult`, `cache === nothing`, and the distribution's parameters in the
# report.
function MLJModelInterface.fit(model::UnivariateFiniteFitter,
                               verbosity, X, y)
    smoothing = model.alpha
    nobs = length(y)
    levels = classes(y)
    nlevels = length(levels)

    counts = Distributions.countmap(y)
    # `get(counts, c, 0)` covers classes absent from the training data:
    prob_given_level = Dict(
        c => (get(counts, c, 0) + smoothing)/(nobs + smoothing*nlevels)
        for c in levels
    )

    fitresult = CategoricalDistributions.UnivariateFinite(prob_given_level)

    verbosity > 0 && @info "Fitted a $fitresult"

    return fitresult, nothing, (params=Distributions.params(fitresult),)
end

# The fitted distribution *is* the prediction; `X` is ignored (there are no
# input features in density estimation).
function MLJModelInterface.predict(model::UnivariateFiniteFitter, fitresult, X)
    return fitresult
end


# No input features are accepted (`X === nothing`): this model is a pure
# density estimator.
MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
    Nothing
# Targets are vectors of finite (categorical) values.
MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
    AbstractVector{<:Finite}

# This test will fail if MLJ test dependency MLJBase is < 1.11
@testset "tuning for density estimators" begin
    # Seven observations over three classes: a appears 2x, b 4x, c 1x.
    y = coerce(collect("abbabbc"), Multiclass)
    X = nothing # density estimation: no input features

    train, test = partition(eachindex(y), 3/7)
    # For above train-test split, hand calculation determines, when optimizing against
    # log loss, that:
    best_alpha = 2.0
    best_loss = (4log(9) - log(3) - 2log(4) - log(2))/4

    model = UnivariateFiniteFitter(alpha=0)
    r = range(model, :alpha, values=[0.1, 1, 1.5, 2, 2.5, 10])
    tmodel = TunedModel(
        model,
        tuning=Grid(shuffle=false), # deterministic sweep over `values`
        range=r,
        resampling=[(train, test),], # single explicit train/test split
        measure=log_loss,
        compact_history=false,
    )

    mach = machine(tmodel, X, y)
    fit!(mach, verbosity=0)
    best = report(mach).best_history_entry
    @test best.model.alpha == best_alpha
    @test best.evaluation.measurement[1] ≈ best_loss
end

true
7 changes: 3 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ end
@test include("serialization.jl")
end

# @testset "julia bug" begin
# @test include("julia_bug.jl")
# end

# Run the density-estimation integration test; the included file returns
# `true` on success. (Fixes typo in the testset name: "estimatation".)
@testset "density estimation" begin
    @test include("density_estimation.jl")
end
11 changes: 11 additions & 0 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,16 @@ end
@test MLJTuning.signature.(measures) == [-1, 0, 1]
end

@testset "trim" begin
    str = "some.long.name" # 14 characters
    # (N, expected) pairs covering the no-trim regime, the one-component-
    # removed regime, and the cannot-go-smaller floor:
    cases = [
        14 => str,
        13 => "…long.name", # 10 characters
        12 => "…long.name",
        11 => "…long.name",
        10 => "…long.name",
        9 => "…name",
        1 => "…name", # cannot go any smaller
    ]
    for (N, expected) in cases
        @test MLJTuning.trim(str, N) == expected
    end
end

true

Loading