Skip to content

Commit fab12d2

Browse files
authored
Merge pull request #114 from DilumAluthge/dpa/verbosity-print-solver-when-fitting
If `verbosity > 0`, print the name of the selected solver when `fit!` is called
2 parents 0e19298 + d835cf8 commit fab12d2

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/mlj/interface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ function MMI.fit(m::Union{REG_MODELS...}, verb::Int, X, y)
3636
features = (sch === nothing) ? nothing : sch.names
3737
reg = glr(m)
3838
solver = m.solver === nothing ? _solver(reg, size(Xmatrix)) : m.solver
39+
verb > 0 && @info "Solver: $(solver)"
3940
# get the parameters
4041
θ = fit(reg, Xmatrix, y; solver=solver)
4142
# return
@@ -72,6 +73,7 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
7273
# NOTE: here the number of classes is either 0 or > 2
7374
clf = glr(m, nclasses)
7475
solver = m.solver === nothing ? _solver(clf, size(Xmatrix)) : m.solver
76+
verb > 0 && @info "Solver: $(solver)"
7577
# get the parameters
7678
θ = fit(clf, Xmatrix, yplain, solver=solver)
7779
# return

test/interface/fitpredict.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,10 @@ end
8484
fp = MLJBase.fitted_params(mach)
8585
@test unique(fp.classes) == [1,2,3]
8686
end
87+
88+
@testset "Fitting a classifier with verbosity=1" begin
89+
mdl = LogisticClassifier()
90+
X, y = MLJBase.@load_iris
91+
mach = MLJBase.machine(mdl, X, y)
92+
@test_logs (:info,"Training Machine{LogisticClassifier,…}.") (:info,"Solver: LBFGS()") MLJBase.fit!(mach; verbosity=1)
93+
end

0 commit comments

Comments
 (0)