Skip to content

Commit 59bc743

Browse files
authored
Merge pull request #230 from JuliaAI/dev
For a 0.8.9 release
2 parents 98050c9 + 59b49bc commit 59bc743

File tree

8 files changed

+122
-10
lines changed

8 files changed

+122
-10
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,19 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
version:
22-
- '1.6'
22+
- 'lts'
2323
- '1'
2424
os:
2525
- ubuntu-latest
2626
arch:
2727
- x64
2828
steps:
2929
- uses: actions/checkout@v2
30-
- uses: julia-actions/setup-julia@v1
30+
- uses: julia-actions/setup-julia@v2
3131
with:
3232
version: ${{ matrix.version }}
3333
arch: ${{ matrix.arch }}
34-
- uses: actions/cache@v1
34+
- uses: julia-actions/cache@v2
3535
env:
3636
cache-name: cache-artifacts
3737
with:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTuning"
22
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.8.8"
4+
version = "0.8.9"
55

66
[deps]
77
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"

src/plotrecipes.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
const MAX_AXIS_LABEL_WIDTH = 20
2+
13
@recipe function f(mach::MLJBase.Machine{<:EitherTunedModel})
24
rep = report(mach)
35
measurement = repr(rep.best_history_entry.measure[1])
46
r = rep.plotting
57
z = r.measurements
68
X = r.parameter_values
7-
guides = r.parameter_names
9+
guides = map(r.parameter_names) do name
10+
trim(name, MAX_AXIS_LABEL_WIDTH)
11+
end
812
scales = r.parameter_scales
913
n = size(X, 2)
1014
indices = LinearIndices((n, n))'

src/tuned_models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
const ERR_SPECIFY_MODEL = ArgumentError(
44
"You need to specify `model=...`, unless `tuning=Explicit()`. ")
55
const ERR_SPECIFY_RANGE = ArgumentError(
6-
"You need to specify `range=...`, unless `tuning=Explicit()` and "*
6+
"You need to specify `range=...`, unless `tuning=Explicit()` "*
77
"and `models=...` is specified instead. ")
88
const ERR_SPECIFY_RANGE_OR_MODELS = ArgumentError(
99
"No `model` specified. Either specify an explicit iterator "*

src/utilities.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,22 @@ signature(measure) =
3939
else
4040
0
4141
end
42+
43+
# Trim a dotted name like "transformed_target_model_deterministic.model.K" to at most
# `N` characters by dropping leading dot-separated parts and prefixing an ellipsis.
# For example, with `N=20`, the string above becomes "…model.K". If only a single part
# remains it is returned (ellipsis-prefixed) even when still longer than `N`.
# Used in plotrecipes.jl.
function trim(str, N)
    n = length(str)
    n <= N && return str
    fits = false
    # reversed so that `pop!` removes the left-most (least specific) part first:
    parts = split(str, ".") |> reverse
    # removes parts until what remains fits, with room for the ellipsis (1 character),
    # or until there is only one part left:
    while !fits && length(parts) > 1
        removed = pop!(parts)
        n -= length(removed) + 1 # the `1` is for the dot, `.`
        if n < N
            fits = true
        end
    end
    return "…"*join(reverse(parts), ".")
end

test/density_estimation.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using Test
2+
using MLJTuning
3+
using MLJBase
4+
using StatisticalMeasures
5+
using StableRNGs
6+
import MLJModelInterface
7+
import StatisticalMeasures: CategoricalDistributions, Distributions
8+
9+
10+
# A toy density estimator: fits a `UnivariateFinite` distribution to some categorical
# data, with a Laplace smoothing option, `alpha`.
mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic
    alpha::Float64
end

# keyword constructor, defaulting to the conventional smoothing of 1.0:
function UnivariateFiniteFitter(; alpha=1.0)
    return UnivariateFiniteFitter(alpha)
end
17+
18+
# Fit a Laplace-smoothed `UnivariateFinite` distribution to the categorical target `y`.
# `X` is ignored (this is a density estimator). Returns the fitted distribution, a
# `nothing` cache, and a report holding the distribution's parameters.
function MLJModelInterface.fit(model::UnivariateFiniteFitter,
                               verbosity, X, y)

    α = model.alpha
    nobs = length(y)
    pool = classes(y)
    nclasses = length(pool)

    counts = Distributions.countmap(y)
    # smoothed probability for each class, including classes unobserved in `y`:
    prob_given_class = Dict(
        c => (get(counts, c, 0) + α)/(nobs + α*nclasses) for c in pool
    )

    fitresult = CategoricalDistributions.UnivariateFinite(prob_given_class)

    verbosity > 0 && @info "Fitted a $fitresult"

    return fitresult, nothing, (params=Distributions.params(fitresult),)
end
39+
40+
# Predictions ignore `X`: the fitted distribution itself is returned.
function MLJModelInterface.predict(model::UnivariateFiniteFitter, fitresult, X)
    return fitresult
end
43+
44+
45+
# trait declarations: the model consumes no input features and targets finite
# (categorical) vectors:
MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) = Nothing
MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) = AbstractVector{<:Finite}
49+
50+
# This test will fail if MLJ test dependency MLJBase is < 1.11
@testset "tuning for density estimators" begin
    y = coerce(collect("abbabbc"), Multiclass)
    X = nothing

    train, test = partition(eachindex(y), 3/7)
    # For above train-test split, hand calculation determines, when optimizing against
    # log loss, that:
    best_alpha = 2.0
    best_loss = (4log(9) - log(3) - 2log(4) - log(2))/4

    model = UnivariateFiniteFitter(alpha=0)
    r = range(model, :alpha, values=[0.1, 1, 1.5, 2, 2.5, 10])
    tmodel = TunedModel(
        model,
        tuning=Grid(shuffle=false),
        range=r,
        resampling=[(train, test),],
        measure=log_loss,
        compact_history=false,
    )

    mach = machine(tmodel, X, y)
    fit!(mach, verbosity=0)
    best = report(mach).best_history_entry
    @test best.model.alpha == best_alpha
    # floats derived from a hand calculation: compare with `isapprox`, not `==`
    @test best.evaluation.measurement[1] ≈ best_loss
end

true

test/runtests.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ end
6060
@test include("serialization.jl")
6161
end
6262

63-
# @testset "julia bug" begin
64-
# @test include("julia_bug.jl")
65-
# end
66-
63+
# "estimation" (typo "estimatation" fixed in the displayed testset label):
@testset "density estimation" begin
    @test include("density_estimation.jl")
end

test/utilities.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,16 @@ end
2222
@test MLJTuning.signature.(measures) == [-1, 0, 1]
2323
end
2424

25+
@testset "trim" begin
    str = "some.long.name" # 14 characters
    # a budget covering the whole string leaves it untouched:
    @test MLJTuning.trim(str, 14) == str
    # dropping "some" gives "…long.name" (10 characters), the result for every
    # budget from 13 down to 10:
    for budget in 13:-1:10
        @test MLJTuning.trim(str, budget) == "…long.name"
    end
    @test MLJTuning.trim(str, 9) == "…name"
    @test MLJTuning.trim(str, 1) == "…name" # cannot go any smaller
end
35+
2536
true
2637

0 commit comments

Comments
 (0)