Skip to content

Commit af03f16

Browse files
authored
Merge pull request #37 from JuliaAI/dev
For a 0.4.1 release
2 parents 6ab6c66 + 794335c commit af03f16

File tree

7 files changed

+107
-64
lines changed

7 files changed

+107
-64
lines changed

.github/codecov.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
coverage:
2+
status:
3+
project:
4+
default:
5+
threshold: 0.5%
6+
removed_code_behavior: fully_covered_patch
7+
patch:
8+
default:
9+
target: 80%

.github/workflows/CI-nightly.yml

Lines changed: 0 additions & 48 deletions
This file was deleted.

.github/workflows/CI.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,5 @@ jobs:
4747
- uses: julia-actions/julia-runtest@latest
4848
- uses: julia-actions/julia-uploadcodecov@v0.1
4949
continue-on-error: true
50-
- uses: julia-actions/julia-uploadcoveralls@v0.1
51-
continue-on-error: true
50+
5251

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJEnsembles"
22
uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/ensembles.jl

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -408,31 +408,56 @@ function _fit(res::CPUProcesses, func, verbosity, stuff)
408408
end
409409
end
410410

411+
# Create thread safe version of RNGs.
412+
# Random._GLOBAL_RNG() and Random.default_rng() are threadsafe by default_rng
413+
# as they have thread local state from julia >=1.3<=1.6 and task local state Julia >=1.7
414+
threadsafe_rng(rng::typeof(Random.default_rng())) = rng
415+
threadsafe_rng(rng::Random._GLOBAL_RNG) = rng
416+
threadsafe_rng(rng) = deepcopy(rng)
417+
411418
function _fit(res::CPUThreads, func, verbosity, stuff)
412419
atom, n, n_patterns, n_train, rng, progress_meter, args = stuff
413420
if verbosity > 0
414421
println("Ensemble-building in parallel on $(Threads.nthreads()) threads.")
415422
end
423+
416424
nthreads = Threads.nthreads()
425+
426+
if nthreads == 1
427+
return _fit(CPU1(), func, verbosity, stuff)
428+
end
429+
417430
chunk_size = div(n, nthreads)
418431
left_over = mod(n, nthreads)
419432
resvec = Vector(undef, nthreads) # FIXME: Make this type-stable?
420433

421-
Threads.@threads for i = 1:nthreads
422-
resvec[i] = if i != nworkers()
423-
func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
424-
else
425-
func(
426-
atom,
427-
0,
428-
chunk_size + left_over,
429-
n_patterns,
430-
n_train,
431-
rng,
432-
progress_meter,
433-
args...,
434+
@sync begin
435+
for i in 1:nthreads-1
436+
Threads.@spawn(
437+
resvec[i] = func(
438+
atom,
439+
0,
440+
chunk_size,
441+
n_patterns,
442+
n_train,
443+
threadsafe_rng(rng),
444+
progress_meter,
445+
args...
446+
)
434447
)
435448
end
449+
Threads.@spawn(
450+
resvec[nthreads] = func(
451+
atom,
452+
0,
453+
chunk_size + left_over,
454+
n_patterns,
455+
n_train,
456+
threadsafe_rng(rng),
457+
progress_meter,
458+
args...
459+
)
460+
)
436461
end
437462

438463
return reduce(_reducer, resvec)

test/ensembles.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,23 @@ end
256256
@test length(ensemble.fitresult.ensemble) == 5
257257

258258
@test !isnan(predict(ensemble, MLJEnsembles.selectrows(X, test))[1])
259+
260+
# tests using integer rngs (see issue 27)
261+
X_, y_ = @load_iris
262+
atom = KNNClassifier(K = 7)
263+
ensemble_model = EnsembleModel(
264+
atom;
265+
bagging_fraction=0.6,
266+
rng=123,
267+
out_of_bag_measure = [log_loss, brier_score]
268+
)
269+
ensemble = machine(ensemble_model, X_, y_)
270+
fit!(ensemble)
271+
@test length(ensemble.fitresult.ensemble) == ensemble_model.n
272+
259273
end
260274

275+
261276
end
262277

263278
true

test/serialization.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,50 @@ end
5858
@test predict(smach, X) == predict(mach, X)
5959

6060
rm(filename)
61+
end
6162

63+
# define a supervised model with ephemeral `fitresult`, but which overcomes this by
64+
# overloading `save`/`restore`:
65+
thing = []
66+
struct EphemeralRegressor <: Deterministic end
67+
function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
68+
# if I serialize/deserialized `thing` then `id` below changes:
69+
id = objectid(thing)
70+
fitresult = (thing, id, mean(y))
71+
return fitresult, nothing, NamedTuple()
72+
end
73+
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
74+
thing, id, μ = fitresult
75+
return id == objectid(thing) ? fill(μ, nrows(X)) :
76+
throw(ErrorException("dead fitresult"))
77+
end
78+
MLJBase.target_scitype(::Type{<:EphemeralRegressor}) = AbstractVector{Continuous}
79+
function MLJBase.save(::EphemeralRegressor, fitresult)
80+
thing, _, μ = fitresult
81+
return (thing, μ)
82+
end
83+
function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
84+
thing, μ = serialized_fitresult
85+
id = objectid(thing)
86+
return (thing, id, μ)
87+
end
88+
89+
@testset "serialization for atomic models with non-persistent fitresults" begin
90+
# https://github.com/alan-turing-institute/MLJ.jl/issues/1099
91+
X, y = (; x = rand(10)), fill(42.0, 3)
92+
ensemble = EnsembleModel(
93+
EphemeralRegressor(),
94+
bagging_fraction=0.7,
95+
n=2,
96+
)
97+
mach = machine(ensemble, X, y)
98+
fit!(mach, verbosity=0)
99+
io = IOBuffer()
100+
MLJBase.save(io, mach)
101+
seekstart(io)
102+
mach2 = machine(io)
103+
close(io)
104+
@test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
62105
end
63106

64107
end

0 commit comments

Comments
 (0)