Skip to content

Commit 97122c3

Browse files
authored
Merge pull request #40 from JuliaAI/dev
For a 0.4.2 release
2 parents af03f16 + f514e18 commit 97122c3

File tree

7 files changed

+97
-33
lines changed

7 files changed

+97
-33
lines changed

.github/workflows/CI.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ jobs:
4545
${{ runner.os }}-
4646
- uses: julia-actions/julia-buildpkg@latest
4747
- uses: julia-actions/julia-runtest@latest
48-
- uses: julia-actions/julia-uploadcodecov@v0.1
49-
continue-on-error: true
48+
env:
49+
JULIA_NUM_THREADS: 2
50+
- uses: julia-actions/julia-processcoverage@v1
51+
- uses: codecov/codecov-action@v3
52+
with:
53+
file: lcov.info
5054

5155

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJEnsembles"
22
uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
A package to create bagged homogeneous ensembles of
66
machine learning models using the
7-
[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) machine
7+
[MLJ](https://JuliaAI.github.io/MLJ.jl/dev/) machine
88
learning framework.
99

1010
For combining models in more general ways, see the [Composing
11-
Models](https://alan-turing-institute.github.io/MLJ.jl/dev/composing_models/#Composing-Models)
11+
Models](https://JuliaAI.github.io/MLJ.jl/dev/composing_models/#Composing-Models)
1212
section of the MLJ manual.
1313

1414

@@ -34,16 +34,16 @@ using MLJBase, MLJEnsembles
3434
In this case you will also need to load code defining an atomic model
3535
to ensemble. The easiest way to do this is run `Pkg.add("MLJModels");
3636
using MLJModels` and use the `@load` macro. See the [Loading Model
37-
Code](https://alan-turing-institute.github.io/MLJ.jl/dev/loading_model_code/)
37+
Code](https://JuliaAI.github.io/MLJ.jl/dev/loading_model_code/)
3838
of the MLJ manual for this and other possibilities.
3939

4040

4141
## Sample usage
4242

43-
See [Data Science Tutorials](https://alan-turing-institute.github.io/DataScienceTutorials.jl/getting-started/ensembles/).
43+
See [Data Science Tutorials](https://JuliaAI.github.io/DataScienceTutorials.jl/getting-started/ensembles/).
4444

4545

4646
## Documentation
4747

48-
See the [MLJ manual](https://alan-turing-institute.github.io/MLJ.jl/dev/homogeneous_ensembles/#Homogeneous-Ensembles).
48+
See the [MLJ manual](https://JuliaAI.github.io/MLJ.jl/dev/homogeneous_ensembles/#Homogeneous-Ensembles).
4949

src/ensembles.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,9 @@ end
412412
# Random._GLOBAL_RNG() and Random.default_rng() are threadsafe out of the
# box: they carry thread-local state on julia >= 1.3, <= 1.6 and task-local
# state on julia >= 1.7, so they may be shared across threads unchanged.
threadsafe_rng(rng::typeof(Random.default_rng())) = rng
# `Random._GLOBAL_RNG` does not exist on all julia versions; only define the
# method where the type is available.
if isdefined(Random, :_GLOBAL_RNG)
    threadsafe_rng(rng::Random._GLOBAL_RNG) = rng
end
# Any other RNG gets a defensive deep copy so each thread works on its own.
threadsafe_rng(rng) = deepcopy(rng)
417419

418420
function _fit(res::CPUThreads, func, verbosity, stuff)

test/ensembles.jl

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
module TestEnsembles
2-
3-
using Test
4-
using Random
5-
using StableRNGs
6-
using MLJEnsembles
7-
using MLJBase
8-
using ..Models
9-
using CategoricalArrays
10-
import Distributions
11-
using StatisticalMeasures
12-
131
## HELPER FUNCTIONS
142

153
@test MLJEnsembles._reducer([1, 2], [3, ]) == [1, 2, 3]
@@ -187,10 +175,10 @@ predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))
187175

188176
@testset "further test of sample weights" begin
189177
## Note: This testset also indirectly tests for compatibility with the data-front end
190-
# implemented by `KNNClassifier` as calls to `fit`/`predict` on an `Ensemble` model
178+
# implemented by `KNNClassifier` as calls to `fit`/`predict` on an `Ensemble` model
191179
# with `atom=KNNClassifier` would error if the ensemble implementation doesn't handle
192180
# data front-end conversions properly.
193-
181+
194182
rng = StableRNG(123)
195183
N = 20
196184
X = (x = rand(rng, 3N), );
@@ -224,18 +212,18 @@ predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))
224212
end
225213

226214

227-
## MACHINE TEST
215+
## MACHINE TEST
228216
## (INCLUDES TEST OF UPDATE.
229-
## ALSO INCLUDES COMPATIBILITY TESTS FOR ENSEMBLES WITH ATOM MODELS HAVING A
217+
## ALSO INCLUDES COMPATIBILITY TESTS FOR ENSEMBLES WITH ATOM MODELS HAVING A
230218
## DIFFERENT DATA FRONT-END SEE #16)
231219

232-
@testset "machine tests" begin
220+
@testset_accelerated "machine tests" acceleration begin
233221
N =100
234222
X = (x1=rand(N), x2=rand(N), x3=rand(N))
235223
y = 2X.x1 - X.x2 + 0.05*rand(N)
236224

237225
atom = KNNRegressor(K=7)
238-
ensemble_model = EnsembleModel(model=atom)
226+
ensemble_model = EnsembleModel(; model=atom, acceleration)
239227
ensemble = machine(ensemble_model, X, y)
240228
train, test = partition(eachindex(y), 0.7)
241229
fit!(ensemble, rows=train, verbosity=0)
@@ -264,15 +252,13 @@ end
264252
atom;
265253
bagging_fraction=0.6,
266254
rng=123,
267-
out_of_bag_measure = [log_loss, brier_score]
255+
out_of_bag_measure = [log_loss, brier_score],
256+
acceleration,
268257
)
269258
ensemble = machine(ensemble_model, X_, y_)
270259
fit!(ensemble)
271260
@test length(ensemble.fitresult.ensemble) == ensemble_model.n
272261

273262
end
274263

275-
276-
end
277-
278264
true

test/runtests.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,26 @@
using Distributed

# Spin up workers sharing the master's project environment.
# Thanks to https://stackoverflow.com/a/70895939/5056635 for the exeflags tip.
addprocs(; exeflags="--project=$(Base.active_project())")

@info "nprocs() = $(nprocs())"
import .Threads
@info "nthreads() = $(Threads.nthreads())"

include("test_utilities.jl")
include_everywhere("_models.jl")

# Make the test dependencies available on every process, not just the master.
@everywhere begin
    using Test
    using Random
    using StableRNGs
    using MLJEnsembles
    using MLJBase
    using ..Models
    using CategoricalArrays
    import Distributions
    using StatisticalMeasures
    import Distributed
end

include("ensembles.jl")
include("serialization.jl")

test/test_utilities.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using Test
2+
3+
using ComputationalResources
4+
5+
# Run `ex` as a `@testset` once per computational resource, binding the
# resource object to `var` inside the testset body.
macro testset_accelerated(name::String, var, ex)
    return testset_accelerated(name, var, ex)
end

# Variant accepting an options expression, e.g. `(exclude=[CPUThreads],)`.
# NOTE(review): `opts` is `eval`ed at macro-expansion time, so it must be a
# literal expression — confirm callers never pass runtime values.
macro testset_accelerated(name::String, var, opts::Expr, ex)
    return testset_accelerated(name, var, ex; eval(opts)...)
end
11+
# Build the expression behind `@testset_accelerated`: the body `ex` runs once
# with `var` bound to `CPU1()`, then once per extra resource (`CPUProcesses()`,
# `CPUThreads()`). Resources whose type appears in `exclude` are replaced by a
# placeholder testset containing a single `@test_broken`, so the skip is
# visible in the test report rather than silent.
function testset_accelerated(name::String, var, ex; exclude=[])
    out = quote
        local $var = CPU1()
        @testset $name $ex
    end

    for resource in AbstractResource[CPUProcesses(), CPUThreads()]
        label = name * " ($(typeof(resource).name))"
        skipped = any(T -> typeof(resource) <: T, exclude)
        piece = if skipped
            quote
                local $var = $resource
                @testset $label begin
                    @test_broken false
                end
            end
        else
            quote
                local $var = $resource
                @testset $label $ex
            end
        end
        push!(out.args, piece)
    end

    # Preserve the caller's source location when the body carries one, so
    # test failures report the right file and line.
    if ex isa Expr && ex.head === :block && !isempty(ex.args) &&
            ex.args[1] isa LineNumberNode
        out = Expr(:block, ex.args[1], out)
    end
    return esc(out)
end
41+
42+
# Load `filepath` on every process: first on the master process (triggering
# any precompilation exactly once), then concurrently on all workers.
function include_everywhere(filepath)
    include(filepath)
    nprocs() > 1 || return
    fullpath = joinpath(@__DIR__, filepath)
    @sync for p in workers()
        @async remotecall_wait(include, p, fullpath)
    end
end

0 commit comments

Comments (0)